diff --git a/.github/workflows/aes.yml b/.github/workflows/aes.yml index f9edc7cb..2f14f21f 100644 --- a/.github/workflows/aes.yml +++ b/.github/workflows/aes.yml @@ -16,6 +16,7 @@ defaults: env: CARGO_INCREMENTAL: 0 RUSTFLAGS: "-Dwarnings" + QEMU_SRC_VERSION: 8.2.0 jobs: # Builds for no_std platforms @@ -234,6 +235,143 @@ jobs: - run: cross test --package aes --target ${{ matrix.target }} --features hazmat - run: cross test --package aes --target ${{ matrix.target }} --all-features + # Build and cache latest QEMUs; needed for RVV and SVE features + qemu-build-and-cache: + runs-on: ubuntu-latest + defaults: + run: + working-directory: /home/runner + steps: + - id: cache-qemu + uses: actions/cache@v3 + with: + path: /opt/qemu-${{ env.QEMU_SRC_VERSION }} + key: ${{ runner.os }}-qemu-${{ env.QEMU_SRC_VERSION }} + - if: ${{ steps.cache-qemu.outputs.cache-hit != 'true' }} + run: | + sudo apt update + DEBIAN_FRONTEND=noninteractive sudo apt --assume-yes install \ + build-essential \ + curl \ + gnupg \ + libglib2.0-dev \ + ninja-build \ + pkg-config \ + python3-venv \ + xz-utils + - if: ${{ steps.cache-qemu.outputs.cache-hit != 'true' }} + run: | + mkdir -p vendor + cd vendor + curl -JLO https://download.qemu.org/qemu-${QEMU_SRC_VERSION}.tar.xz.sig + curl -JLO https://download.qemu.org/qemu-${QEMU_SRC_VERSION}.tar.xz + gpg --keyserver hkps://keys.openpgp.org --recv-keys CEACC9E15534EBABB82D3FA03353C9CEF108B584 + gpg --verify qemu-${QEMU_SRC_VERSION}.tar.xz.sig qemu-${QEMU_SRC_VERSION}.tar.xz + tar xvf qemu-${QEMU_SRC_VERSION}.tar.xz + - if: ${{ steps.cache-qemu.outputs.cache-hit != 'true' }} + run: | + cd vendor/qemu-${QEMU_SRC_VERSION} + ./configure \ + --prefix=/opt/qemu-${QEMU_SRC_VERSION} \ + --without-default-features \ + --without-default-devices \ + --disable-system \ + --target-list=aarch64-linux-user,riscv32-linux-user,riscv64-linux-user \ + --static + make -j + make install + + # ARMv9 cross-compiled tests for SVE2-AES + armv9: + needs: qemu-build-and-cache + 
strategy: + matrix: + include: + - target: aarch64-unknown-linux-gnu + rust: 1.72.0 # MSRV + runs-on: ubuntu-latest + env: + LLVM_MAJOR_VERSION: 17 + steps: + - uses: actions/checkout@v3 + - id: discover-ubuntu-codename + shell: bash + run: echo "codename=$(lsb_release -cs)" >> $GITHUB_OUTPUT + - id: cache-qemu + uses: actions/cache@v3 + with: + path: /opt/qemu-${{ env.QEMU_SRC_VERSION }} + key: ${{ runner.os }}-qemu-${{ env.QEMU_SRC_VERSION }} + - run: echo "/opt/qemu-${QEMU_SRC_VERSION}/bin" >> $GITHUB_PATH + - run: | + sudo apt update + DEBIAN_FRONTEND=noninteractive sudo apt install --assume-yes \ + curl \ + gnupg + - run: sudo dpkg --add-architecture arm64 + - run: sudo sed -i'' -E 's/^(deb|deb-src) /\1 [arch=amd64] /' /etc/apt/sources.list + - run: | + echo "deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ ${{ steps.discover-ubuntu-codename.outputs.codename }} main restricted" | sudo tee --append /etc/apt/sources.list.d/arm64.list + echo "deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ ${{ steps.discover-ubuntu-codename.outputs.codename }}-updates main restricted" | sudo tee --append /etc/apt/sources.list.d/arm64.list + echo "deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ ${{ steps.discover-ubuntu-codename.outputs.codename }} universe" | sudo tee --append /etc/apt/sources.list.d/arm64.list + echo "deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ ${{ steps.discover-ubuntu-codename.outputs.codename }}-updates universe" | sudo tee --append /etc/apt/sources.list.d/arm64.list + echo "deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ ${{ steps.discover-ubuntu-codename.outputs.codename }} multiverse" | sudo tee --append /etc/apt/sources.list.d/arm64.list + echo "deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ ${{ steps.discover-ubuntu-codename.outputs.codename }}-updates multiverse" | sudo tee --append /etc/apt/sources.list.d/arm64.list + echo "deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ ${{ 
steps.discover-ubuntu-codename.outputs.codename }}-backports main restricted universe multiverse" | sudo tee --append /etc/apt/sources.list.d/arm64.list + echo "deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ ${{ steps.discover-ubuntu-codename.outputs.codename }}-security main restricted" | sudo tee --append /etc/apt/sources.list.d/arm64.list + echo "deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ ${{ steps.discover-ubuntu-codename.outputs.codename }}-security universe" | sudo tee --append /etc/apt/sources.list.d/arm64.list + echo "deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ ${{ steps.discover-ubuntu-codename.outputs.codename }}-security multiverse" | sudo tee --append /etc/apt/sources.list.d/arm64.list + - run: | + curl -JL https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - + echo "deb [arch=amd64] http://apt.llvm.org/${{ steps.discover-ubuntu-codename.outputs.codename }}/ llvm-toolchain-${{ steps.discover-ubuntu-codename.outputs.codename }}-${LLVM_MAJOR_VERSION} main" | sudo tee --append /etc/apt/sources.list.d/llvm.list + echo "deb-src [arch=amd64] http://apt.llvm.org/${{ steps.discover-ubuntu-codename.outputs.codename }}/ llvm-toolchain-${{ steps.discover-ubuntu-codename.outputs.codename }}-${LLVM_MAJOR_VERSION} main" | sudo tee --append /etc/apt/sources.list.d/llvm.list + - shell: bash + run: | + sudo apt update + DEBIAN_FRONTEND=noninteractive sudo apt install --assume-yes \ + binfmt-support build-essential clang-${LLVM_MAJOR_VERSION} clang-tools-${LLVM_MAJOR_VERSION} lld-${LLVM_MAJOR_VERSION} \ + libc6:{amd64,arm64} \ + libc6-dev:{amd64,arm64} \ + libgcc-12-dev:{amd64,arm64} \ + libgcc-s1:{amd64,arm64} \ + libstdc++-12-dev:{amd64,arm64} \ + linux-libc-dev:{amd64,arm64} \ + libglib2.0-0 + - uses: RustCrypto/actions/cargo-cache@master + - uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ matrix.rust }} + targets: ${{ matrix.target }} + - name: write .cargo/config.toml + shell: bash + run: | + cd ../aes/.. 
+          mkdir -p .cargo +          echo '[target.aarch64-unknown-linux-gnu]' >> .cargo/config.toml +          echo 'runner = "qemu-aarch64"' >> .cargo/config.toml +          echo 'linker = "clang-17"' >> .cargo/config.toml +          echo 'rustflags = [' >> .cargo/config.toml +          echo '  "-C", "link-arg=-fuse-ld=lld-17",' >> .cargo/config.toml +          echo '  "-C", "link-arg=-march=armv9-a+sve+sve2+sve2-aes",' >> .cargo/config.toml +          echo '  "-C", "link-arg=--target=aarch64-unknown-linux-gnu",' >> .cargo/config.toml +          echo '  "-C", "target-feature=+sve,+sve2,+sve2-aes"' >> .cargo/config.toml +          echo ']' >> .cargo/config.toml +      - name: cargo test --package aes +        run: | +          unset RUSTFLAGS +          QEMU_CPU="max" cargo test --package aes --target ${{ matrix.target }} +          QEMU_CPU="max,sve256=off" cargo test --package aes --target ${{ matrix.target }} +      - name: cargo test --package aes --features hazmat +        run: | +          unset RUSTFLAGS +          QEMU_CPU="max" cargo test --package aes --target ${{ matrix.target }} --features hazmat +          QEMU_CPU="max,sve256=off" cargo test --package aes --target ${{ matrix.target }} --features hazmat +      - name: cargo test --package aes --all-features +        run: | +          unset RUSTFLAGS +          QEMU_CPU="max" cargo test --package aes --target ${{ matrix.target }} --all-features +          QEMU_CPU="max,sve256=off" cargo test --package aes --target ${{ matrix.target }} --all-features +    clippy: env: RUSTFLAGS: "-Dwarnings --cfg aes_compact" diff --git a/aes/src/armv9.rs b/aes/src/armv9.rs new file mode 100644 index 00000000..46f837ac --- /dev/null +++ b/aes/src/armv9.rs @@ -0,0 +1,334 @@ +//! AES block cipher implementation using the ARMv9 SVE2-AES feature. +//! +//! NOTE: The key-schedule routines currently reuse the ARMv8 NEON implementation but the intention +//! is to replace them with a full SVE implementation. +//! +//! NOTE: The rest of the cipher implementation is not based on an existing implementation but is +//! generally adapted from the similar implementation written for this crate targeting RISC-V RVV.
+ +mod encdec; +mod expand; +#[cfg(test)] +mod test_expand; + +// TODO(silvanshade): +// - implement key-schedule using sve +// - interleave key-schedule for par blocks (unroll loop and use more vector registers) +// - hazmat +// - benchmarks vs neon + +use crate::{Block, Block8}; +use cipher::{ + consts::{U16, U24, U32, U8}, + inout::InOut, + AlgorithmName, BlockBackend, BlockCipher, BlockCipherDecrypt, BlockCipherEncrypt, BlockClosure, + BlockSizeUser, Key, KeyInit, KeySizeUser, ParBlocksSizeUser, +}; +use core::{arch::aarch64::uint8x16_t, fmt}; + +type RoundKey = uint8x16_t; +type RoundKeys = [RoundKey; N]; + +macro_rules! define_aes_impl { + ( + $module:ident, + $name:ident, + $name_enc:ident, + $name_dec:ident, + $name_back_enc:ident, + $name_back_dec:ident, + $key_size:ty, + $words:tt, + $rounds:tt, + $doc:expr $(,)? + ) => { + #[doc=$doc] + #[doc = "block cipher"] + #[derive(Clone)] + pub struct $name { + encrypt: $name_enc, + decrypt: $name_dec, + } + + impl BlockCipher for $name {} + + impl KeySizeUser for $name { + type KeySize = $key_size; + } + + impl KeyInit for $name { + #[inline] + fn new(key: &Key) -> Self { + let encrypt = $name_enc::new(key); + let decrypt = $name_dec::from(&encrypt); + Self { encrypt, decrypt } + } + } + + impl From<$name_enc> for $name { + #[inline] + fn from(encrypt: $name_enc) -> $name { + let decrypt = (&encrypt).into(); + Self { encrypt, decrypt } + } + } + + impl From<&$name_enc> for $name { + #[inline] + fn from(encrypt: &$name_enc) -> $name { + let decrypt = encrypt.into(); + let encrypt = encrypt.clone(); + Self { encrypt, decrypt } + } + } + + impl BlockSizeUser for $name { + type BlockSize = U16; + } + + impl BlockCipherEncrypt for $name { + fn encrypt_with_backend(&self, f: impl BlockClosure) { + self.encrypt.encrypt_with_backend(f) + } + } + + impl BlockCipherDecrypt for $name { + fn decrypt_with_backend(&self, f: impl BlockClosure) { + self.decrypt.decrypt_with_backend(f) + } + } + + impl fmt::Debug for $name { + fn 
fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + f.write_str(concat!(stringify!($name), " { .. }")) + } + } + + impl AlgorithmName for $name { + fn write_alg_name(f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(stringify!($name)) + } + } + + #[cfg(feature = "zeroize")] + impl zeroize::ZeroizeOnDrop for $name {} + + #[doc=$doc] + #[doc = "block cipher (encrypt-only)"] + #[derive(Clone)] + pub struct $name_enc { + round_keys: RoundKeys<$rounds>, + } + + impl $name_enc { + #[inline(always)] + pub(crate) fn get_enc_backend(&self) -> $name_back_enc<'_> { + $name_back_enc(self) + } + } + + impl BlockCipher for $name_enc {} + + impl KeySizeUser for $name_enc { + type KeySize = $key_size; + } + + impl KeyInit for $name_enc { + #[inline] + fn new(key: &Key) -> Self { + Self { + round_keys: self::expand::$module::expand_key(key), + } + } + } + + impl BlockSizeUser for $name_enc { + type BlockSize = U16; + } + + impl BlockCipherEncrypt for $name_enc { + fn encrypt_with_backend(&self, f: impl BlockClosure) { + f.call(&mut self.get_enc_backend()) + } + } + + impl fmt::Debug for $name_enc { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + f.write_str(concat!(stringify!($name_enc), " { .. 
}")) + } + } + + impl AlgorithmName for $name_enc { + fn write_alg_name(f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(stringify!($name_enc)) + } + } + + impl Drop for $name_enc { + #[inline] + fn drop(&mut self) { + #[cfg(feature = "zeroize")] + zeroize::Zeroize::zeroize(&mut self.round_keys); + } + } + + #[cfg(feature = "zeroize")] + impl zeroize::ZeroizeOnDrop for $name_enc {} + + #[doc=$doc] + #[doc = "block cipher (decrypt-only)"] + #[derive(Clone)] + pub struct $name_dec { + round_keys: RoundKeys<$rounds>, + } + + impl $name_dec { + #[inline(always)] + pub(crate) fn get_dec_backend(&self) -> $name_back_dec<'_> { + $name_back_dec(self) + } + } + + impl BlockCipher for $name_dec {} + + impl KeySizeUser for $name_dec { + type KeySize = $key_size; + } + + impl KeyInit for $name_dec { + #[inline] + fn new(key: &Key) -> Self { + $name_enc::new(key).into() + } + } + + impl From<$name_enc> for $name_dec { + #[inline] + fn from(enc: $name_enc) -> $name_dec { + Self::from(&enc) + } + } + + impl From<&$name_enc> for $name_dec { + fn from(enc: &$name_enc) -> $name_dec { + let mut round_keys = enc.round_keys; + self::expand::$module::inv_expanded_keys(&mut round_keys); + Self { round_keys } + } + } + + impl BlockSizeUser for $name_dec { + type BlockSize = U16; + } + + impl BlockCipherDecrypt for $name_dec { + fn decrypt_with_backend(&self, f: impl BlockClosure) { + f.call(&mut self.get_dec_backend()); + } + } + + impl fmt::Debug for $name_dec { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + f.write_str(concat!(stringify!($name_dec), " { .. 
}")) + } + } + + impl AlgorithmName for $name_dec { + fn write_alg_name(f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(stringify!($name_dec)) + } + } + + impl Drop for $name_dec { + #[inline] + fn drop(&mut self) { + #[cfg(feature = "zeroize")] + zeroize::Zeroize::zeroize(&mut self.round_keys); + } + } + + #[cfg(feature = "zeroize")] + impl zeroize::ZeroizeOnDrop for $name_dec {} + + pub(crate) struct $name_back_enc<'a>(&'a $name_enc); + + impl<'a> BlockSizeUser for $name_back_enc<'a> { + type BlockSize = U16; + } + + impl<'a> ParBlocksSizeUser for $name_back_enc<'a> { + type ParBlocksSize = U8; + } + + impl<'a> BlockBackend for $name_back_enc<'a> { + #[inline(always)] + fn proc_block(&mut self, block: InOut<'_, '_, Block>) { + self::encdec::$module::encrypt1(&self.0.round_keys, block); + } + + #[inline(always)] + fn proc_par_blocks(&mut self, blocks: InOut<'_, '_, Block8>) { + self::encdec::$module::encrypt8(&self.0.round_keys, blocks) + } + } + + pub(crate) struct $name_back_dec<'a>(&'a $name_dec); + + impl<'a> BlockSizeUser for $name_back_dec<'a> { + type BlockSize = U16; + } + + impl<'a> ParBlocksSizeUser for $name_back_dec<'a> { + type ParBlocksSize = U8; + } + + impl<'a> BlockBackend for $name_back_dec<'a> { + #[inline(always)] + fn proc_block(&mut self, block: InOut<'_, '_, Block>) { + self::encdec::$module::decrypt1(&self.0.round_keys, block); + } + + #[inline(always)] + fn proc_par_blocks(&mut self, blocks: InOut<'_, '_, Block8>) { + self::encdec::$module::decrypt8(&self.0.round_keys, blocks) + } + } + }; +} + +define_aes_impl!( + aes128, + Aes128, + Aes128Enc, + Aes128Dec, + Aes128BackEnc, + Aes128BackDec, + U16, + 2, + 11, + "AES-128", +); +define_aes_impl!( + aes192, + Aes192, + Aes192Enc, + Aes192Dec, + Aes192BackEnc, + Aes192BackDec, + U24, + 3, + 13, + "AES-192", +); +define_aes_impl!( + aes256, + Aes256, + Aes256Enc, + Aes256Dec, + Aes256BackEnc, + Aes256BackDec, + U32, + 4, + 15, + "AES-256", +); diff --git a/aes/src/armv9/encdec.rs 
b/aes/src/armv9/encdec.rs new file mode 100644 index 00000000..98071f66 --- /dev/null +++ b/aes/src/armv9/encdec.rs @@ -0,0 +1,5 @@ +use super::RoundKeys; + +pub(super) mod aes128; +pub(super) mod aes192; +pub(super) mod aes256; diff --git a/aes/src/armv9/encdec/aes128.rs b/aes/src/armv9/encdec/aes128.rs new file mode 100644 index 00000000..07b6f2a2 --- /dev/null +++ b/aes/src/armv9/encdec/aes128.rs @@ -0,0 +1,178 @@ +use super::RoundKeys; +use crate::{Block, Block8}; +use cipher::inout::InOut; +use core::arch::global_asm; + +// TODO(silvanshade): switch to intrinsics when available +#[rustfmt::skip] +global_asm! { + ".balign 8", // align section to 8 bytes + ".global aes_armv9_encdec_aes128_encrypt", // declare symbol + ".type aes_armv9_encdec_aes128_encrypt, %function", // declare symbol as function type + "aes_armv9_encdec_aes128_encrypt:", // start function + "mov x8, #0", // set loop counter {x8} to 0 + + "whilelt p0.b, x8, x2", // set avl to len {x2} - loop counter {x8}, represented as a predicate + "b.none 2f", // exit if avl == 0 + + "ld1rqb {{ z0.b}}, p0/z, [x3, #0 ]", // broadcast load round 00 key + "ld1rqb {{ z1.b}}, p0/z, [x3, #16]", // broadcast load round 01 key + "ld1rqb {{ z2.b}}, p0/z, [x3, #32]", // broadcast load round 02 key + "ld1rqb {{ z3.b}}, p0/z, [x3, #48]", "add x3, x3, #64", // broadcast load round 03 key; increment key pointer by 4 indices + "ld1rqb {{ z4.b}}, p0/z, [x3, #0 ]", // broadcast load round 04 key + "ld1rqb {{ z5.b}}, p0/z, [x3, #16]", // broadcast load round 05 key + "ld1rqb {{ z6.b}}, p0/z, [x3, #32]", // broadcast load round 06 key + "ld1rqb {{ z7.b}}, p0/z, [x3, #48]", "add x3, x3, #64", // broadcast load round 07 key; increment key pointer by 4 indices + "ld1rqb {{ z8.b}}, p0/z, [x3, #0 ]", // broadcast load round 08 key + "ld1rqb {{ z9.b}}, p0/z, [x3, #16]", // broadcast load round 09 key + "ld1rqb {{z10.b}}, p0/z, [x3, #32]", // broadcast load round 10 key + "1:", + "ld1b z31.b, p0/z, [x1]", // load avl bytes of 
plain-data + + "aese z31.b, z31.b, z0.b", "aesmc z31.b, z31.b", // perform AES-128 round 00 encryption + "aese z31.b, z31.b, z1.b", "aesmc z31.b, z31.b", // perform AES-128 round 01 encryption + "aese z31.b, z31.b, z2.b", "aesmc z31.b, z31.b", // perform AES-128 round 02 encryption + "aese z31.b, z31.b, z3.b", "aesmc z31.b, z31.b", // perform AES-128 round 03 encryption + "aese z31.b, z31.b, z4.b", "aesmc z31.b, z31.b", // perform AES-128 round 04 encryption + "aese z31.b, z31.b, z5.b", "aesmc z31.b, z31.b", // perform AES-128 round 05 encryption + "aese z31.b, z31.b, z6.b", "aesmc z31.b, z31.b", // perform AES-128 round 06 encryption + "aese z31.b, z31.b, z7.b", "aesmc z31.b, z31.b", // perform AES-128 round 07 encryption + "aese z31.b, z31.b, z8.b", "aesmc z31.b, z31.b", // perform AES-128 round 08 encryption + "aese z31.b, z31.b, z9.b", // perform AES-128 round 09 encryption + "eor z31.b, z31.b, z10.b", // perform AES-128 round 10 encryption + + "st1b z31.b, p0, [x0]", // save avl bytes of cipher-data + + "incb x0", "incb x1", "incb x8", // increment plain-data pointer, cipher-data pointer, loop counter by avl indices + + "whilelt p0.b, x8, x2", // set avl to len {x2} - loop counter {x8}, represented as a predicate + "b.first 1b", // loop if (0 < avl) + "2:", + "ret", +} +extern "C" { + pub fn aes_armv9_encdec_aes128_encrypt( + dst: *mut u8, + src: *const u8, + len: usize, + key: *const u8, + ); +} + +// TODO(silvanshade): switch to intrinsics when available +#[rustfmt::skip] +global_asm! 
{ + ".balign 8", // align section to 8 bytes + ".global aes_armv9_encdec_aes128_decrypt", // declare symbol + ".type aes_armv9_encdec_aes128_decrypt, %function", // declare symbol as function type + "aes_armv9_encdec_aes128_decrypt:", // start function + "mov x8, #0", // set loop counter {x8} to 0 + + "whilelt p0.b, x8, x2", // set avl to len {x2} - loop counter {x8}, represented as a predicate + "b.none 2f", // exit if avl == 0 + + "ld1rqb {{ z0.b}}, p0/z, [x3, #0 ]", // broadcast load round 00 key + "ld1rqb {{ z1.b}}, p0/z, [x3, #16]", // broadcast load round 01 key + "ld1rqb {{ z2.b}}, p0/z, [x3, #32]", // broadcast load round 02 key + "ld1rqb {{ z3.b}}, p0/z, [x3, #48]", "add x3, x3, #64", // broadcast load round 03 key; increment key pointer by 4 indices + "ld1rqb {{ z4.b}}, p0/z, [x3, #0 ]", // broadcast load round 04 key + "ld1rqb {{ z5.b}}, p0/z, [x3, #16]", // broadcast load round 05 key + "ld1rqb {{ z6.b}}, p0/z, [x3, #32]", // broadcast load round 06 key + "ld1rqb {{ z7.b}}, p0/z, [x3, #48]", "add x3, x3, #64", // broadcast load round 07 key; increment key pointer by 4 indices + "ld1rqb {{ z8.b}}, p0/z, [x3, #0 ]", // broadcast load round 08 key + "ld1rqb {{ z9.b}}, p0/z, [x3, #16]", // broadcast load round 09 key + "ld1rqb {{z10.b}}, p0/z, [x3, #32]", // broadcast load round 10 key + "1:", + "ld1b z31.b, p0/z, [x1]", // load vl bytes of cipher-data + + "aesd z31.b, z31.b, z10.b", "aesimc z31.b, z31.b", // perform AES-128 round 10 decryption + "aesd z31.b, z31.b, z9.b", "aesimc z31.b, z31.b", // perform AES-128 round 09 decryption + "aesd z31.b, z31.b, z8.b", "aesimc z31.b, z31.b", // perform AES-128 round 08 decryption + "aesd z31.b, z31.b, z7.b", "aesimc z31.b, z31.b", // perform AES-128 round 07 decryption + "aesd z31.b, z31.b, z6.b", "aesimc z31.b, z31.b", // perform AES-128 round 06 decryption + "aesd z31.b, z31.b, z5.b", "aesimc z31.b, z31.b", // perform AES-128 round 05 decryption + "aesd z31.b, z31.b, z4.b", "aesimc z31.b, z31.b", // perform 
AES-128 round 04 decryption + "aesd z31.b, z31.b, z3.b", "aesimc z31.b, z31.b", // perform AES-128 round 03 decryption + "aesd z31.b, z31.b, z2.b", "aesimc z31.b, z31.b", // perform AES-128 round 02 decryption + "aesd z31.b, z31.b, z1.b", // perform AES-128 round 01 decryption + "eor z31.b, z31.b, z0.b", // perform AES-128 round 00 decryption + + "st1b z31.b, p0, [x0]", // save vl bytes of plain-data + + "incb x0", "incb x1", "incb x8", // increment plain-data pointer, cipher-data pointer, loop counter by avl indices + + "whilelt p0.b, x8, x2", // set avl to len {x2} - loop counter {x8}, represented as a predicate + "b.first 1b", // loop if (0 < avl) + "2:", + "ret", +} +extern "C" { + pub fn aes_armv9_encdec_aes128_decrypt( + dst: *mut u8, + src: *const u8, + len: usize, + key: *const u8, + ); +} + +#[inline(always)] +pub fn encrypt_vla(keys: &RoundKeys<11>, mut data: InOut<'_, '_, Block>, blocks: usize) { + let dst = data.get_out().as_mut_ptr(); + let src = data.get_in().as_ptr(); + let len = blocks * 16; + let key = keys.as_ptr().cast::(); + unsafe { aes_armv9_encdec_aes128_encrypt(dst, src, len, key) }; +} + +#[inline(always)] +pub(crate) fn encrypt1(keys: &RoundKeys<11>, mut data: InOut<'_, '_, Block>) { + let data = unsafe { + InOut::from_raw( + data.get_in().as_ptr().cast::(), + data.get_out().as_mut_ptr().cast::(), + ) + }; + encrypt_vla(keys, data, 1) +} + +#[inline(always)] +pub(crate) fn encrypt8(keys: &RoundKeys<11>, mut data: InOut<'_, '_, Block8>) { + let data = unsafe { + InOut::from_raw( + data.get_in().as_ptr().cast::(), + data.get_out().as_mut_ptr().cast::(), + ) + }; + encrypt_vla(keys, data, 8) +} + +#[inline(always)] +pub fn decrypt_vla(keys: &RoundKeys<11>, mut data: InOut<'_, '_, Block>, blocks: usize) { + let dst = data.get_out().as_mut_ptr(); + let src = data.get_in().as_ptr(); + let len = blocks * 16; + let key = keys.as_ptr().cast::(); + unsafe { aes_armv9_encdec_aes128_decrypt(dst, src, len, key) }; +} + +#[inline(always)] +pub(crate) fn 
decrypt1(keys: &RoundKeys<11>, mut data: InOut<'_, '_, Block>) { + let data = unsafe { + InOut::from_raw( + data.get_in().as_ptr().cast::(), + data.get_out().as_mut_ptr().cast::(), + ) + }; + decrypt_vla(keys, data, 1) +} + +#[inline(always)] +pub(crate) fn decrypt8(keys: &RoundKeys<11>, mut data: InOut<'_, '_, Block8>) { + let data = unsafe { + InOut::from_raw( + data.get_in().as_ptr().cast::(), + data.get_out().as_mut_ptr().cast::(), + ) + }; + decrypt_vla(keys, data, 8) +} diff --git a/aes/src/armv9/encdec/aes192.rs b/aes/src/armv9/encdec/aes192.rs new file mode 100644 index 00000000..304642b4 --- /dev/null +++ b/aes/src/armv9/encdec/aes192.rs @@ -0,0 +1,186 @@ +use super::RoundKeys; +use crate::{Block, Block8}; +use cipher::inout::InOut; +use core::arch::global_asm; + +// TODO(silvanshade): switch to intrinsics when available +#[rustfmt::skip] +global_asm!{ + ".balign 8", // align section to 8 bytes + ".global aes_armv9_encdec_aes192_encrypt", // declare symbol + ".type aes_armv9_encdec_aes192_encrypt, %function", // declare symbol as function type + "aes_armv9_encdec_aes192_encrypt:", // start function + "mov x8, #0", // set loop counter {x8} to 0 + + "whilelt p0.b, x8, x2", // set avl to len {x2} - loop counter {x8}, represented as a predicate + "b.none 2f", // exit if avl == 0 + + "ld1rqb {{ z0.b}}, p0/z, [x3, #0 ]", // broadcast load round 00 key + "ld1rqb {{ z1.b}}, p0/z, [x3, #16]", // broadcast load round 01 key + "ld1rqb {{ z2.b}}, p0/z, [x3, #32]", // broadcast load round 02 key + "ld1rqb {{ z3.b}}, p0/z, [x3, #48]", "add x3, x3, #64", // broadcast load round 03 key; increment key pointer by 4 indices + "ld1rqb {{ z4.b}}, p0/z, [x3, #0 ]", // broadcast load round 04 key + "ld1rqb {{ z5.b}}, p0/z, [x3, #16]", // broadcast load round 05 key + "ld1rqb {{ z6.b}}, p0/z, [x3, #32]", // broadcast load round 06 key + "ld1rqb {{ z7.b}}, p0/z, [x3, #48]", "add x3, x3, #64", // broadcast load round 07 key; increment key pointer by 4 indices + "ld1rqb {{ z8.b}}, 
p0/z, [x3, #0 ]", // broadcast load round 08 key + "ld1rqb {{ z9.b}}, p0/z, [x3, #16]", // broadcast load round 09 key + "ld1rqb {{z10.b}}, p0/z, [x3, #32]", // broadcast load round 10 key + "ld1rqb {{z11.b}}, p0/z, [x3, #48]", "add x3, x3, #64", // broadcast load round 11 key; increment key pointer by 4 indices + "ld1rqb {{z12.b}}, p0/z, [x3, #0 ]", // broadcast load round 12 key + "1:", + "ld1b z31.b, p0/z, [x1]", // load avl bytes of plain-data + + "aese z31.b, z31.b, z0.b", "aesmc z31.b, z31.b", // perform AES-192 round 00 encryption + "aese z31.b, z31.b, z1.b", "aesmc z31.b, z31.b", // perform AES-192 round 01 encryption + "aese z31.b, z31.b, z2.b", "aesmc z31.b, z31.b", // perform AES-192 round 02 encryption + "aese z31.b, z31.b, z3.b", "aesmc z31.b, z31.b", // perform AES-192 round 03 encryption + "aese z31.b, z31.b, z4.b", "aesmc z31.b, z31.b", // perform AES-192 round 04 encryption + "aese z31.b, z31.b, z5.b", "aesmc z31.b, z31.b", // perform AES-192 round 05 encryption + "aese z31.b, z31.b, z6.b", "aesmc z31.b, z31.b", // perform AES-192 round 06 encryption + "aese z31.b, z31.b, z7.b", "aesmc z31.b, z31.b", // perform AES-192 round 07 encryption + "aese z31.b, z31.b, z8.b", "aesmc z31.b, z31.b", // perform AES-192 round 08 encryption + "aese z31.b, z31.b, z9.b", "aesmc z31.b, z31.b", // perform AES-192 round 09 encryption + "aese z31.b, z31.b, z10.b", "aesmc z31.b, z31.b", // perform AES-192 round 10 encryption + "aese z31.b, z31.b, z11.b", // perform AES-192 round 11 encryption + "eor z31.b, z31.b, z12.b", // perform AES-192 round 12 encryption + + "st1b z31.b, p0, [x0]", // save avl bytes of cipher-data + + "incb x0", "incb x1", "incb x8", // increment plain-data pointer, cipher-data pointer, loop counter by avl indices + + "whilelt p0.b, x8, x2", // set avl to len {x2} - loop counter {x8}, represented as a predicate + "b.first 1b", // loop if (0 < avl) + "2:", + "ret", +} +extern "C" { + pub fn aes_armv9_encdec_aes192_encrypt( + dst: *mut u8, + src:
*const u8, + len: usize, + key: *const u8, + ); +} + +// TODO(silvanshade): switch to intrinsics when available +#[rustfmt::skip] +global_asm! { + ".balign 8", // align section to 8 bytes + ".global aes_armv9_encdec_aes192_decrypt", // declare symbol + ".type aes_armv9_encdec_aes192_decrypt, %function", // declare symbol as function type + "aes_armv9_encdec_aes192_decrypt:", // start function + "mov x8, #0", // set loop counter {x8} to 0 + + "whilelt p0.b, x8, x2", // set avl to len {x2} - loop counter {x8}, represented as a predicate + "b.none 2f", // exit if avl == 0 + + "ld1rqb {{ z0.b}}, p0/z, [x3, #0 ]", // broadcast load round 00 key + "ld1rqb {{ z1.b}}, p0/z, [x3, #16]", // broadcast load round 01 key + "ld1rqb {{ z2.b}}, p0/z, [x3, #32]", // broadcast load round 02 key + "ld1rqb {{ z3.b}}, p0/z, [x3, #48]", "add x3, x3, #64", // broadcast load round 03 key; increment key pointer by 4 indices + "ld1rqb {{ z4.b}}, p0/z, [x3, #0 ]", // broadcast load round 04 key + "ld1rqb {{ z5.b}}, p0/z, [x3, #16]", // broadcast load round 05 key + "ld1rqb {{ z6.b}}, p0/z, [x3, #32]", // broadcast load round 06 key + "ld1rqb {{ z7.b}}, p0/z, [x3, #48]", "add x3, x3, #64", // broadcast load round 07 key; increment key pointer by 4 indices + "ld1rqb {{ z8.b}}, p0/z, [x3, #0 ]", // broadcast load round 08 key + "ld1rqb {{ z9.b}}, p0/z, [x3, #16]", // broadcast load round 09 key + "ld1rqb {{z10.b}}, p0/z, [x3, #32]", // broadcast load round 10 key + "ld1rqb {{z11.b}}, p0/z, [x3, #48]", "add x3, x3, #64", // broadcast load round 11 key; increment key pointer by 4 indices + "ld1rqb {{z12.b}}, p0/z, [x3, #0 ]", // broadcast load round 12 key + "1:", + "ld1b z31.b, p0/z, [x1]", // load avl bytes of cipher-data + + "aesd z31.b, z31.b, z12.b", "aesimc z31.b, z31.b", // perform AES-192 round 12 decryption + "aesd z31.b, z31.b, z11.b", "aesimc z31.b, z31.b", // perform AES-192 round 11 decryption + "aesd z31.b, z31.b, z10.b", "aesimc z31.b, z31.b", // perform AES-192 round 10 decryption
+ "aesd z31.b, z31.b, z9.b", "aesimc z31.b, z31.b", // perform AES-192 round 09 decryption + "aesd z31.b, z31.b, z8.b", "aesimc z31.b, z31.b", // perform AES-192 round 08 decryption + "aesd z31.b, z31.b, z7.b", "aesimc z31.b, z31.b", // perform AES-192 round 07 decryption + "aesd z31.b, z31.b, z6.b", "aesimc z31.b, z31.b", // perform AES-192 round 06 decryption + "aesd z31.b, z31.b, z5.b", "aesimc z31.b, z31.b", // perform AES-192 round 05 decryption + "aesd z31.b, z31.b, z4.b", "aesimc z31.b, z31.b", // perform AES-192 round 04 decryption + "aesd z31.b, z31.b, z3.b", "aesimc z31.b, z31.b", // perform AES-192 round 03 decryption + "aesd z31.b, z31.b, z2.b", "aesimc z31.b, z31.b", // perform AES-192 round 02 decryption + "aesd z31.b, z31.b, z1.b", // perform AES-192 round 01 decryption + "eor z31.b, z31.b, z0.b", // perform AES-192 round 00 decryption + + "st1b z31.b, p0, [x0]", // save avl bytes of plain-data + + "incb x0", "incb x1", "incb x8", // increment plain-data pointer, cipher-data pointer, loop counter by avl indices + + "whilelt p0.b, x8, x2", // set avl to len {x2} - loop counter {x8}, represented as a predicate + "b.first 1b", // loop if (0 < avl) + "2:", + "ret", +} +extern "C" { + pub fn aes_armv9_encdec_aes192_decrypt( + dst: *mut u8, + src: *const u8, + len: usize, + key: *const u8, + ); +} + +#[inline(always)] +fn encrypt_vla(keys: &RoundKeys<13>, mut data: InOut<'_, '_, Block>, blocks: usize) { + let dst = data.get_out().as_mut_ptr(); + let src = data.get_in().as_ptr(); + let len = blocks * 16; + let key = keys.as_ptr().cast::<u8>(); + unsafe { aes_armv9_encdec_aes192_encrypt(dst, src, len, key) }; +} + +#[inline(always)] +pub(crate) fn encrypt1(keys: &RoundKeys<13>, mut data: InOut<'_, '_, Block>) { + let data = unsafe { + InOut::from_raw( + data.get_in().as_ptr().cast::<Block>(), + data.get_out().as_mut_ptr().cast::<Block>(), + ) + }; + encrypt_vla(keys, data, 1) +} + +#[inline(always)] +pub(crate) fn encrypt8(keys: &RoundKeys<13>, mut data: InOut<'_, '_,
Block8>) { + let data = unsafe { + InOut::from_raw( + data.get_in().as_ptr().cast::(), + data.get_out().as_mut_ptr().cast::(), + ) + }; + encrypt_vla(keys, data, 8) +} + +#[inline(always)] +fn decrypt_vla(keys: &RoundKeys<13>, mut data: InOut<'_, '_, Block>, blocks: usize) { + let dst = data.get_out().as_mut_ptr(); + let src = data.get_in().as_ptr(); + let len = blocks * 16; + let key = keys.as_ptr().cast::(); + unsafe { aes_armv9_encdec_aes192_decrypt(dst, src, len, key) }; +} + +#[inline(always)] +pub(crate) fn decrypt1(keys: &RoundKeys<13>, mut data: InOut<'_, '_, Block>) { + let data = unsafe { + InOut::from_raw( + data.get_in().as_ptr().cast::(), + data.get_out().as_mut_ptr().cast::(), + ) + }; + decrypt_vla(keys, data, 1) +} + +#[inline(always)] +pub(crate) fn decrypt8(keys: &RoundKeys<13>, mut data: InOut<'_, '_, Block8>) { + let data = unsafe { + InOut::from_raw( + data.get_in().as_ptr().cast::(), + data.get_out().as_mut_ptr().cast::(), + ) + }; + decrypt_vla(keys, data, 8) +} diff --git a/aes/src/armv9/encdec/aes256.rs b/aes/src/armv9/encdec/aes256.rs new file mode 100644 index 00000000..cf40467b --- /dev/null +++ b/aes/src/armv9/encdec/aes256.rs @@ -0,0 +1,194 @@ +use super::RoundKeys; +use crate::{Block, Block8}; +use cipher::inout::InOut; +use core::arch::global_asm; + +// TODO(silvanshade): switch to intrinsics when available +#[rustfmt::skip] +global_asm!{ + ".balign 8", // align section to 8 bytes + ".global aes_armv9_encdec_aes256_encrypt", // declare symbol + ".type aes_armv9_encdec_aes256_encrypt, %function", // declare symbol as function type + "aes_armv9_encdec_aes256_encrypt:", // start function + "mov x8, #0", // set loop counter {x8} to 0 + + "whilelt p0.b, x8, x2", // set avl to len {x2} - loop counter {x8}, represented as a predicate + "b.none 2f", // exit if avl == 0 + + "ld1rqb {{ z0.b}}, p0/z, [x3, #0 ]", // broadcast load round 00 key + "ld1rqb {{ z1.b}}, p0/z, [x3, #16]", // broadcast load round 01 key + "ld1rqb {{ z2.b}}, p0/z, [x3, 
#32]", // broadcast load round 02 key + "ld1rqb {{ z3.b}}, p0/z, [x3, #48]", "add x3, x3, #64", // broadcast load round 03 key; increment key pointer by 4 indices + "ld1rqb {{ z4.b}}, p0/z, [x3, #0 ]", // broadcast load round 04 key + "ld1rqb {{ z5.b}}, p0/z, [x3, #16]", // broadcast load round 05 key + "ld1rqb {{ z6.b}}, p0/z, [x3, #32]", // broadcast load round 06 key + "ld1rqb {{ z7.b}}, p0/z, [x3, #48]", "add x3, x3, #64", // broadcast load round 07 key; increment key pointer by 4 indices + "ld1rqb {{ z8.b}}, p0/z, [x3, #0 ]", // broadcast load round 08 key + "ld1rqb {{ z9.b}}, p0/z, [x3, #16]", // broadcast load round 09 key + "ld1rqb {{z10.b}}, p0/z, [x3, #32]", // broadcast load round 10 key + "ld1rqb {{z11.b}}, p0/z, [x3, #48]", "add x3, x3, #64", // broadcast load round 11 key; increment key pointer by 4 indices + "ld1rqb {{z12.b}}, p0/z, [x3, #0 ]", // broadcast load round 12 key + "ld1rqb {{z13.b}}, p0/z, [x3, #16]", // broadcast load round 13 key + "ld1rqb {{z14.b}}, p0/z, [x3, #32]", // broadcast load round 14 key + "1:", + "ld1b z31.b, p0/z, [x1]", // load avl bytes of plain-data + + "aese z31.b, z31.b, z0.b", "aesmc z31.b, z31.b", // perform AES-128 round 00 encryption + "aese z31.b, z31.b, z1.b", "aesmc z31.b, z31.b", // perform AES-128 round 01 encryption + "aese z31.b, z31.b, z2.b", "aesmc z31.b, z31.b", // perform AES-128 round 02 encryption + "aese z31.b, z31.b, z3.b", "aesmc z31.b, z31.b", // perform AES-128 round 03 encryption + "aese z31.b, z31.b, z4.b", "aesmc z31.b, z31.b", // perform AES-128 round 04 encryption + "aese z31.b, z31.b, z5.b", "aesmc z31.b, z31.b", // perform AES-128 round 05 encryption + "aese z31.b, z31.b, z6.b", "aesmc z31.b, z31.b", // perform AES-128 round 06 encryption + "aese z31.b, z31.b, z7.b", "aesmc z31.b, z31.b", // perform AES-128 round 07 encryption + "aese z31.b, z31.b, z8.b", "aesmc z31.b, z31.b", // perform AES-128 round 08 encryption + "aese z31.b, z31.b, z9.b", "aesmc z31.b, z31.b", // perform AES-128 round 
09 encryption + "aese z31.b, z31.b, z10.b", "aesmc z31.b, z31.b", // perform AES-128 round 10 encryption + "aese z31.b, z31.b, z11.b", "aesmc z31.b, z31.b", // perform AES-128 round 11 encryption + "aese z31.b, z31.b, z12.b", "aesmc z31.b, z31.b", // perform AES-128 round 12 encryption + "aese z31.b, z31.b, z13.b", // perform AES-128 round 13 encryption + "eor z31.b, z31.b, z14.b", // perform AES-128 round 14 encryption + + "st1b z31.b, p0, [x0]", // save avl bytes of cipher-data + + "incb x0", "incb x1", "incb x8", // increment plain-data pointer, cipher-data pointer, loop counter by avl indices + + "whilelt p0.b, x8, x2", // set avl to len {x2} - loop counter {x8}, represented as a predicate + "b.first 1b", // exit if (0 < avl) + "2:", + "ret", +} +extern "C" { + pub fn aes_armv9_encdec_aes256_encrypt( + dst: *mut u8, + src: *const u8, + len: usize, + key: *const u8, + ); +} + +// TODO(silvanshade): switch to intrinsics when available +#[rustfmt::skip] +global_asm! { + ".balign 8", // align section to 8 bytes + ".global aes_armv9_encdec_aes256_decrypt", // declare symbol + ".type aes_armv9_encdec_aes256_decrypt, %function", // declare symbol as function type + "aes_armv9_encdec_aes256_decrypt:", // start function + "mov x8, #0", // set x8 to 0 + + "whilelt p0.b, x8, x2", // set p0.b to 1 if (0 {x8} < len {x2}), otherwise 0 + "b.none 2f", // branch and exit early if !(0 < len) + + "ld1rqb {{ z0.b}}, p0/z, [x3, #0 ]", // broadcast load round 00 key + "ld1rqb {{ z1.b}}, p0/z, [x3, #16]", // broadcast load round 01 key + "ld1rqb {{ z2.b}}, p0/z, [x3, #32]", // broadcast load round 02 key + "ld1rqb {{ z3.b}}, p0/z, [x3, #48]", "add x3, x3, #64", // broadcast load round 03 key; increment key pointer by 4 indices + "ld1rqb {{ z4.b}}, p0/z, [x3, #0 ]", // broadcast load round 04 key + "ld1rqb {{ z5.b}}, p0/z, [x3, #16]", // broadcast load round 05 key + "ld1rqb {{ z6.b}}, p0/z, [x3, #32]", // broadcast load round 06 key + "ld1rqb {{ z7.b}}, p0/z, [x3, #48]", "add x3, x3, 
#64", // broadcast load round 07 key; increment key pointer by 4 indices + "ld1rqb {{ z8.b}}, p0/z, [x3, #0 ]", // broadcast load round 08 key + "ld1rqb {{ z9.b}}, p0/z, [x3, #16]", // broadcast load round 09 key + "ld1rqb {{z10.b}}, p0/z, [x3, #32]", // broadcast load round 10 key + "ld1rqb {{z11.b}}, p0/z, [x3, #48]", "add x3, x3, #64", // broadcast load round 11 key; increment key pointer by 4 indices + "ld1rqb {{z12.b}}, p0/z, [x3, #0 ]", // broadcast load round 12 key + "ld1rqb {{z13.b}}, p0/z, [x3, #16]", // broadcast load round 13 key + "ld1rqb {{z14.b}}, p0/z, [x3, #32]", // broadcast load round 14 key + "1:", + "ld1b z31.b, p0/z, [x1]", // load vl bytes of cipher-data + + "aesd z31.b, z31.b, z14.b", "aesimc z31.b, z31.b", // perform AES-128 round 14 decryption + "aesd z31.b, z31.b, z13.b", "aesimc z31.b, z31.b", // perform AES-128 round 13 decryption + "aesd z31.b, z31.b, z12.b", "aesimc z31.b, z31.b", // perform AES-128 round 12 decryption + "aesd z31.b, z31.b, z11.b", "aesimc z31.b, z31.b", // perform AES-128 round 11 decryption + "aesd z31.b, z31.b, z10.b", "aesimc z31.b, z31.b", // perform AES-128 round 10 decryption + "aesd z31.b, z31.b, z9.b", "aesimc z31.b, z31.b", // perform AES-128 round 09 decryption + "aesd z31.b, z31.b, z8.b", "aesimc z31.b, z31.b", // perform AES-128 round 08 decryption + "aesd z31.b, z31.b, z7.b", "aesimc z31.b, z31.b", // perform AES-128 round 07 decryption + "aesd z31.b, z31.b, z6.b", "aesimc z31.b, z31.b", // perform AES-128 round 06 decryption + "aesd z31.b, z31.b, z5.b", "aesimc z31.b, z31.b", // perform AES-128 round 05 decryption + "aesd z31.b, z31.b, z4.b", "aesimc z31.b, z31.b", // perform AES-128 round 04 decryption + "aesd z31.b, z31.b, z3.b", "aesimc z31.b, z31.b", // perform AES-128 round 03 decryption + "aesd z31.b, z31.b, z2.b", "aesimc z31.b, z31.b", // perform AES-128 round 02 decryption + "aesd z31.b, z31.b, z1.b", // perform AES-128 round 01 decryption + "eor z31.b, z31.b, z0.b", // perform AES-128 round 00 
decryption + + "st1b z31.b, p0, [x0]", // save avl bytes of plain-data + + "incb x0", "incb x1", "incb x8", // increment plain-data pointer, cipher-data pointer, loop counter by avl indices + + "whilelt p0.b, x8, x2", // set avl to len {x2} - loop counter {x8}, represented as a predicate + "b.first 1b", // exit if (0 < avl) + "2:", + "ret", +} +extern "C" { + pub fn aes_armv9_encdec_aes256_decrypt( + dst: *mut u8, + src: *const u8, + len: usize, + key: *const u8, + ); +} + +#[inline(always)] +fn encrypt_vla(keys: &RoundKeys<15>, mut data: InOut<'_, '_, Block>, blocks: usize) { + let dst = data.get_out().as_mut_ptr(); + let src = data.get_in().as_ptr(); + let len = blocks * 16; + let key = keys.as_ptr().cast::(); + unsafe { aes_armv9_encdec_aes256_encrypt(dst, src, len, key) }; +} + +#[inline(always)] +pub(crate) fn encrypt1(keys: &RoundKeys<15>, mut data: InOut<'_, '_, Block>) { + let data = unsafe { + InOut::from_raw( + data.get_in().as_ptr().cast::(), + data.get_out().as_mut_ptr().cast::(), + ) + }; + encrypt_vla(keys, data, 1) +} + +#[inline(always)] +pub(crate) fn encrypt8(keys: &RoundKeys<15>, mut data: InOut<'_, '_, Block8>) { + let data = unsafe { + InOut::from_raw( + data.get_in().as_ptr().cast::(), + data.get_out().as_mut_ptr().cast::(), + ) + }; + encrypt_vla(keys, data, 8) +} + +#[inline(always)] +fn decrypt_vla(keys: &RoundKeys<15>, mut data: InOut<'_, '_, Block>, blocks: usize) { + let dst = data.get_out().as_mut_ptr(); + let src = data.get_in().as_ptr(); + let len = blocks * 16; + let key = keys.as_ptr().cast::(); + unsafe { aes_armv9_encdec_aes256_decrypt(dst, src, len, key) }; +} + +#[inline(always)] +pub(crate) fn decrypt1(keys: &RoundKeys<15>, mut data: InOut<'_, '_, Block>) { + let data = unsafe { + InOut::from_raw( + data.get_in().as_ptr().cast::(), + data.get_out().as_mut_ptr().cast::(), + ) + }; + decrypt_vla(keys, data, 1) +} + +#[inline(always)] +pub(crate) fn decrypt8(keys: &RoundKeys<15>, mut data: InOut<'_, '_, Block8>) { + let data = 
unsafe { + InOut::from_raw( + data.get_in().as_ptr().cast::(), + data.get_out().as_mut_ptr().cast::(), + ) + }; + decrypt_vla(keys, data, 8) +} diff --git a/aes/src/armv9/expand.rs b/aes/src/armv9/expand.rs new file mode 100644 index 00000000..4b828af4 --- /dev/null +++ b/aes/src/armv9/expand.rs @@ -0,0 +1,76 @@ +use super::RoundKeys; + +pub(super) mod aes128; +pub(super) mod aes192; +pub(super) mod aes256; + +use core::{arch::aarch64::*, mem, slice}; + +// TODO(silvanshade): remove this and replace with armv9 sve2 key expansion + +/// There are 4 AES words in a block. +const BLOCK_WORDS: usize = 4; + +/// The AES (nee Rijndael) notion of a word is always 32-bits, or 4-bytes. +const WORD_SIZE: usize = 4; + +/// AES round constants. +const ROUND_CONSTS: [u32; 10] = [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36]; + +/// AES key expansion. +#[target_feature(enable = "aes")] +pub unsafe fn expand_key(key: &[u8; L]) -> [uint8x16_t; N] { + assert!((L == 16 && N == 11) || (L == 24 && N == 13) || (L == 32 && N == 15)); + + let mut expanded_keys: [uint8x16_t; N] = mem::zeroed(); + + let columns = + slice::from_raw_parts_mut(expanded_keys.as_mut_ptr() as *mut u32, N * BLOCK_WORDS); + + for (i, chunk) in key.chunks_exact(WORD_SIZE).enumerate() { + columns[i] = u32::from_ne_bytes(chunk.try_into().unwrap()); + } + + // From "The Rijndael Block Cipher" Section 4.1: + // > The number of columns of the Cipher Key is denoted by `Nk` and is + // > equal to the key length divided by 32 [bits]. + let nk = L / WORD_SIZE; + + for i in nk..(N * BLOCK_WORDS) { + let mut word = columns[i - 1]; + + if i % nk == 0 { + word = sub_word(word).rotate_right(8) ^ ROUND_CONSTS[i / nk - 1]; + } else if nk > 6 && i % nk == 4 { + word = sub_word(word); + } + + columns[i] = columns[i - nk] ^ word; + } + + expanded_keys +} + +/// Compute inverse expanded keys (for decryption). 
+/// +/// This is the reverse of the encryption keys, with the Inverse Mix Columns +/// operation applied to all but the first and last expanded key. +#[target_feature(enable = "aes")] +pub(super) unsafe fn inv_expanded_keys(expanded_keys: &mut [uint8x16_t; N]) { + assert!(N == 11 || N == 13 || N == 15); + for ek in expanded_keys.iter_mut().take(N - 1).skip(1) { + *ek = vaesimcq_u8(*ek); + } +} + +/// Sub bytes for a single AES word: used for key expansion. +#[inline] +#[target_feature(enable = "aes")] +unsafe fn sub_word(input: u32) -> u32 { + let input = vreinterpretq_u8_u32(vdupq_n_u32(input)); + + // AES single round encryption (with a "round" key of all zeros) + let sub_input = vaeseq_u8(input, vdupq_n_u8(0)); + + vgetq_lane_u32(vreinterpretq_u32_u8(sub_input), 0) +} diff --git a/aes/src/armv9/expand/aes128.rs b/aes/src/armv9/expand/aes128.rs new file mode 100644 index 00000000..ecf41201 --- /dev/null +++ b/aes/src/armv9/expand/aes128.rs @@ -0,0 +1,24 @@ +use crate::armv9::expand::RoundKeys; +use cipher::{array::Array, typenum::U16}; + +// // TODO(silvanshade): switch to intrinsics when available +// #[rustfmt::skip] +// global_asm! 
{ +// ".balign 8", +// ".global aes_armv9_expand_aes128_expand_key", +// ".type aes_armv9_expand_aes128_expand_key, %function", +// "aes_armv9_expand_aes128_expand_key:", +// } +// extern "C" { +// fn aes_armv9_expand_aes128_expand_key(dst: *mut u8, src: *const u8); +// } + +#[inline(always)] +pub fn expand_key(key: &Array) -> RoundKeys<11> { + unsafe { crate::armv9::expand::expand_key(key.as_ref()) } +} + +#[inline(always)] +pub fn inv_expanded_keys(expanded_keys: &mut RoundKeys<11>) { + unsafe { crate::armv9::expand::inv_expanded_keys(expanded_keys) } +} diff --git a/aes/src/armv9/expand/aes192.rs b/aes/src/armv9/expand/aes192.rs new file mode 100644 index 00000000..63e0eb75 --- /dev/null +++ b/aes/src/armv9/expand/aes192.rs @@ -0,0 +1,25 @@ +use cipher::{array::Array, typenum::U24}; + +use super::RoundKeys; + +// // TODO(silvanshade): switch to intrinsics when available +// #[rustfmt::skip] +// global_asm! { +// ".balign 8", +// ".global aes_armv9_expand_aes192_expand_key", +// ".type aes_armv9_expand_aes192_expand_key, %function", +// "aes_armv9_expand_aes192_expand_key:", +// } +// extern "C" { +// fn aes_armv9_expand_aes192_expand_key(dst: *mut u8, src: *const u8); +// } + +#[inline(always)] +pub fn expand_key(key: &Array) -> RoundKeys<13> { + unsafe { crate::armv9::expand::expand_key(key.as_ref()) } +} + +#[inline(always)] +pub fn inv_expanded_keys(expanded_keys: &mut RoundKeys<13>) { + unsafe { crate::armv9::expand::inv_expanded_keys(expanded_keys) } +} diff --git a/aes/src/armv9/expand/aes256.rs b/aes/src/armv9/expand/aes256.rs new file mode 100644 index 00000000..77aaa67b --- /dev/null +++ b/aes/src/armv9/expand/aes256.rs @@ -0,0 +1,25 @@ +use cipher::{array::Array, typenum::U32}; + +use super::RoundKeys; + +// // TODO(silvanshade): switch to intrinsics when available +// #[rustfmt::skip] +// global_asm! 
{ +// ".balign 8", +// ".global aes_armv9_expand_aes256_expand_key", +// ".type aes_armv9_expand_aes256_expand_key, %function", +// "aes_armv9_expand_aes256_expand_key:", +// } +// extern "C" { +// fn aes_armv9_expand_aes256_expand_key(dst: *mut u8, src: *const u8); +// } + +#[inline(always)] +pub fn expand_key(key: &Array) -> RoundKeys<15> { + unsafe { crate::armv9::expand::expand_key(key.as_ref()) } +} + +#[inline(always)] +pub fn inv_expanded_keys(expanded_keys: &mut RoundKeys<15>) { + unsafe { crate::armv9::expand::inv_expanded_keys(expanded_keys) } +} diff --git a/aes/src/armv9/test_expand.rs b/aes/src/armv9/test_expand.rs new file mode 100644 index 00000000..ab0d886d --- /dev/null +++ b/aes/src/armv9/test_expand.rs @@ -0,0 +1,166 @@ +use crate::armv9::RoundKeys; +use core::arch::aarch64::*; +use hex_literal::hex; + +const AES128_KEY: [u8; 16] = hex!("2b7e151628aed2a6abf7158809cf4f3c"); +const AES128_EXP_KEYS: [[u8; 16]; 11] = [ + AES128_KEY, + hex!("a0fafe1788542cb123a339392a6c7605"), + hex!("f2c295f27a96b9435935807a7359f67f"), + hex!("3d80477d4716fe3e1e237e446d7a883b"), + hex!("ef44a541a8525b7fb671253bdb0bad00"), + hex!("d4d1c6f87c839d87caf2b8bc11f915bc"), + hex!("6d88a37a110b3efddbf98641ca0093fd"), + hex!("4e54f70e5f5fc9f384a64fb24ea6dc4f"), + hex!("ead27321b58dbad2312bf5607f8d292f"), + hex!("ac7766f319fadc2128d12941575c006e"), + hex!("d014f9a8c9ee2589e13f0cc8b6630ca6"), +]; +const AES128_EXP_INVKEYS: [[u8; 16]; 11] = [ + AES128_KEY, + hex!("2b3708a7f262d405bc3ebdbf4b617d62"), + hex!("cc7505eb3e17d1ee82296c51c9481133"), + hex!("7c1f13f74208c219c021ae480969bf7b"), + hex!("90884413d280860a12a128421bc89739"), + hex!("6ea30afcbc238cf6ae82a4b4b54a338d"), + hex!("6efcd876d2df54807c5df034c917c3b9"), + hex!("12c07647c01f22c7bc42d2f37555114a"), + hex!("df7d925a1f62b09da320626ed6757324"), + hex!("0c7b5a631319eafeb0398890664cfbb4"), + hex!("d014f9a8c9ee2589e13f0cc8b6630ca6"), +]; + +const AES192_KEY: [u8; 24] = hex!("8e73b0f7da0e6452c810f32b809079e562f8ead2522c6b7b"); 
+const AES192_EXP_KEYS: [[u8; 16]; 13] = [ + hex!("8e73b0f7da0e6452c810f32b809079e5"), + hex!("62f8ead2522c6b7bfe0c91f72402f5a5"), + hex!("ec12068e6c827f6b0e7a95b95c56fec2"), + hex!("4db7b4bd69b5411885a74796e92538fd"), + hex!("e75fad44bb095386485af05721efb14f"), + hex!("a448f6d94d6dce24aa326360113b30e6"), + hex!("a25e7ed583b1cf9a27f939436a94f767"), + hex!("c0a69407d19da4e1ec1786eb6fa64971"), + hex!("485f703222cb8755e26d135233f0b7b3"), + hex!("40beeb282f18a2596747d26b458c553e"), + hex!("a7e1466c9411f1df821f750aad07d753"), + hex!("ca4005388fcc5006282d166abc3ce7b5"), + hex!("e98ba06f448c773c8ecc720401002202"), +]; +const AES192_EXP_INVKEYS: [[u8; 16]; 13] = [ + hex!("8e73b0f7da0e6452c810f32b809079e5"), + hex!("9eb149c479d69c5dfeb4a27ceab6d7fd"), + hex!("659763e78c817087123039436be6a51e"), + hex!("41b34544ab0592b9ce92f15e421381d9"), + hex!("5023b89a3bc51d84d04b19377b4e8b8e"), + hex!("b5dc7ad0f7cffb09a7ec43939c295e17"), + hex!("c5ddb7f8be933c760b4f46a6fc80bdaf"), + hex!("5b6cfe3cc745a02bf8b9a572462a9904"), + hex!("4d65dfa2b1e5620dea899c312dcc3c1a"), + hex!("f3b42258b59ebb5cf8fb64fe491e06f3"), + hex!("a3979ac28e5ba6d8e12cc9e654b272ba"), + hex!("ac491644e55710b746c08a75c89b2cad"), + hex!("e98ba06f448c773c8ecc720401002202"), +]; + +const AES256_KEY: [u8; 32] = + hex!("603deb1015ca71be2b73aef0857d77811f352c073b6108d72d9810a30914dff4"); +const AES256_EXP_KEYS: [[u8; 16]; 15] = [ + hex!("603deb1015ca71be2b73aef0857d7781"), + hex!("1f352c073b6108d72d9810a30914dff4"), + hex!("9ba354118e6925afa51a8b5f2067fcde"), + hex!("a8b09c1a93d194cdbe49846eb75d5b9a"), + hex!("d59aecb85bf3c917fee94248de8ebe96"), + hex!("b5a9328a2678a647983122292f6c79b3"), + hex!("812c81addadf48ba24360af2fab8b464"), + hex!("98c5bfc9bebd198e268c3ba709e04214"), + hex!("68007bacb2df331696e939e46c518d80"), + hex!("c814e20476a9fb8a5025c02d59c58239"), + hex!("de1369676ccc5a71fa2563959674ee15"), + hex!("5886ca5d2e2f31d77e0af1fa27cf73c3"), + hex!("749c47ab18501ddae2757e4f7401905a"), + 
hex!("cafaaae3e4d59b349adf6acebd10190d"), + hex!("fe4890d1e6188d0b046df344706c631e"), +]; +const AES256_EXP_INVKEYS: [[u8; 16]; 15] = [ + hex!("603deb1015ca71be2b73aef0857d7781"), + hex!("8ec6bff6829ca03b9e49af7edba96125"), + hex!("42107758e9ec98f066329ea193f8858b"), + hex!("4a7459f9c8e8f9c256a156bc8d083799"), + hex!("6c3d632985d1fbd9e3e36578701be0f3"), + hex!("54fb808b9c137949cab22ff547ba186c"), + hex!("25ba3c22a06bc7fb4388a28333934270"), + hex!("d669a7334a7ade7a80c8f18fc772e9e3"), + hex!("c440b289642b757227a3d7f114309581"), + hex!("32526c367828b24cf8e043c33f92aa20"), + hex!("34ad1e4450866b367725bcc763152946"), + hex!("b668b621ce40046d36a047ae0932ed8e"), + hex!("57c96cf6074f07c0706abb07137f9241"), + hex!("ada23f4963e23b2455427c8a5c709104"), + hex!("fe4890d1e6188d0b046df344706c631e"), +]; + +fn load_expanded_keys(input: [[u8; 16]; N]) -> RoundKeys { + let mut output = [unsafe { vdupq_n_u8(0) }; N]; + + for (src, dst) in input.iter().zip(output.iter_mut()) { + *dst = unsafe { vld1q_u8(src.as_ptr()) } + } + + output +} + +fn store_expanded_keys(input: RoundKeys) -> [[u8; 16]; N] { + let mut output = [[0u8; 16]; N]; + + for (src, dst) in input.iter().zip(output.iter_mut()) { + unsafe { vst1q_u8(dst.as_mut_ptr(), *src) } + } + + output +} + +// NOTE: Unlike RISC-V scalar crypto instructions, RISC-V vector crypto instructions implicitly +// perform key inversion as part of the cipher coding instructions. There are no distinct vector +// instructions for key inversion. Hence, no definition of `inv_expanded_keys` used below. 
+ +#[test] +fn aes128_key_expansion() { + let key = AES128_KEY; + let ek = crate::armv9::expand::aes128::expand_key(&key.into()); + assert_eq!(store_expanded_keys(ek), AES128_EXP_KEYS); +} + +#[test] +fn aes128_key_expansion_inv() { + let mut ek = load_expanded_keys(AES128_EXP_KEYS); + crate::armv9::expand::aes128::inv_expanded_keys(&mut ek); + assert_eq!(store_expanded_keys(ek), AES128_EXP_INVKEYS); +} + +#[test] +fn aes192_key_expansion() { + let key = AES192_KEY; + let ek = crate::armv9::expand::aes192::expand_key(&key.into()); + assert_eq!(store_expanded_keys(ek), AES192_EXP_KEYS); +} + +#[test] +fn aes192_key_expansion_inv() { + let mut ek = load_expanded_keys(AES192_EXP_KEYS); + crate::armv9::expand::aes192::inv_expanded_keys(&mut ek); + assert_eq!(store_expanded_keys(ek), AES192_EXP_INVKEYS); +} + +#[test] +fn aes256_key_expansion() { + let key = AES256_KEY; + let ek = crate::armv9::expand::aes256::expand_key(&key.into()); + assert_eq!(store_expanded_keys(ek), AES256_EXP_KEYS); +} + +#[test] +fn aes256_key_expansion_inv() { + let mut ek = load_expanded_keys(AES256_EXP_KEYS); + crate::armv9::expand::aes256::inv_expanded_keys(&mut ek); + assert_eq!(store_expanded_keys(ek), AES256_EXP_INVKEYS); +} diff --git a/aes/src/hazmat.rs b/aes/src/hazmat.rs index 3d4def91..649eab04 100644 --- a/aes/src/hazmat.rs +++ b/aes/src/hazmat.rs @@ -11,6 +11,8 @@ //! We do NOT recommend using it to implement any algorithm which has not //! received extensive peer review by cryptographers. +// TODO(silvanshade): armv9 sve2 hazmat + use crate::{soft::fixslice::hazmat as soft, Block, Block8}; #[cfg(all(target_arch = "aarch64", not(aes_force_soft)))] diff --git a/aes/src/lib.rs b/aes/src/lib.rs index 0f8bab50..bb68368e 100644 --- a/aes/src/lib.rs +++ b/aes/src/lib.rs @@ -128,10 +128,15 @@ mod soft; use cfg_if::cfg_if; cfg_if! 
{ - if #[cfg(all(target_arch = "aarch64", not(aes_force_soft)))] { + if #[cfg(all(target_arch = "aarch64", not(target_feature = "sve2-aes"), not(aes_force_soft)))] { mod armv8; mod autodetect; pub use autodetect::*; + } else if #[cfg(all(target_arch = "aarch64", target_feature = "sve2-aes", not(aes_force_soft)))] { + #[cfg(feature = "hazmat")] // TODO(silvanshade): remove once armv9 sve2 hazmat is implemented + mod armv8; + mod armv9; + pub use armv9::*; } else if #[cfg(all( any(target_arch = "x86", target_arch = "x86_64"), not(aes_force_soft)