From 7326618da601a4d6d28ad29e33ae33ce6ffd869d Mon Sep 17 00:00:00 2001 From: "Tomohisa Tanaka (Maroontress)" Date: Wed, 16 Aug 2023 18:52:58 +0900 Subject: [PATCH] Use ARM Neon Intrinsics (v7) (#7) --- libmimicssl-aes128-cbc-decrypt/CMakeLists.txt | 14 +- .../src/arm_v7_Aes128Cbc.c | 175 +++++++++ testsuite/CMakeLists.txt | 16 +- testsuite/arm_v7_main.cxx | 359 ++++++++++++++++++ 4 files changed, 561 insertions(+), 3 deletions(-) create mode 100644 libmimicssl-aes128-cbc-decrypt/src/arm_v7_Aes128Cbc.c create mode 100644 testsuite/arm_v7_main.cxx diff --git a/libmimicssl-aes128-cbc-decrypt/CMakeLists.txt b/libmimicssl-aes128-cbc-decrypt/CMakeLists.txt index 5a0a36b..443efe7 100644 --- a/libmimicssl-aes128-cbc-decrypt/CMakeLists.txt +++ b/libmimicssl-aes128-cbc-decrypt/CMakeLists.txt @@ -13,7 +13,19 @@ endif() add_library(mimicssl-aes128-cbc-decrypt STATIC) add_library(mimicssl-aes128-cbc-decrypt-shared SHARED) -if("${CMAKE_SYSTEM_NAME}" STREQUAL "Darwin" +if("${CMAKE_SYSTEM_NAME}" STREQUAL "Android") + if("${CMAKE_ANDROID_ARCH_ABI}" STREQUAL "arm64-v8a") + set(SOURCES src/evp.c src/arm_v7_Aes128Cbc.c) + elseif("${CMAKE_ANDROID_ARCH_ABI}" STREQUAL "armeabi-v7a") + set(OPTIONS -mfpu=neon) + set(SOURCES src/evp.c src/arm_v7_Aes128Cbc.c) + elseif("${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "x86_64") + set(OPTIONS -msse3 -maes) + set(SOURCES src/evp.c src/x86_64_Aes128Cbc.c) + else() + set(SOURCES src/evp.c src/Aes128Cbc.c) + endif() +elseif("${CMAKE_SYSTEM_NAME}" STREQUAL "Darwin" AND "${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "arm64") set(SOURCES src/evp.c src/aarch64_Aes128Cbc.c) elseif("${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "x86_64") diff --git a/libmimicssl-aes128-cbc-decrypt/src/arm_v7_Aes128Cbc.c b/libmimicssl-aes128-cbc-decrypt/src/arm_v7_Aes128Cbc.c new file mode 100644 index 0000000..e39c074 --- /dev/null +++ b/libmimicssl-aes128-cbc-decrypt/src/arm_v7_Aes128Cbc.c @@ -0,0 +1,175 @@ +/* + References: + + https://en.wikipedia.org/wiki/Advanced_Encryption_Standard + https://en.wikipedia.org/wiki/AES_key_schedule + https://en.wikipedia.org/wiki/Rijndael_S-box + https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.197.pdf +*/ + +/* + References of ARM Neon Intrinsics: + https://arm-software.github.io/acle/neon_intrinsics/advsimd.html + https://developer.arm.com/documentation/dui0472/m/Using-NEON-Support?lang=en +*/ + +#include +#include "Aes128Cbc.h" + +#include "sbox.h" +#include "rsbox.h" +#include "rcon.h" +#include "multiply.h" + +static void +keyExpansion(const struct Aes128Cbc_Key *key, struct Aes128Cbc_RoundKey *out) +{ + out->round[0] = *key; + uint32_t *abcd = (uint32_t *)&(out->round[0].data[12]); + uint32_t t = *abcd; + + // prev **** **** abcd + // next + for (uint32_t round = 1; round < 11; ++round) { + uint32_t t0 = (t >> 8) | (t << 24); + uint8_t a0 = (uint8_t)t0; + uint8_t b0 = (uint8_t)(t0 >> 8); + uint8_t c0 = (uint8_t)(t0 >> 16); + uint8_t d0 = (uint8_t)(t0 >> 24); + t = (SBOX[a0] ^ RCON[round - 1]) + | (SBOX[b0] << 8) + | (SBOX[c0] << 16) + | (SBOX[d0] << 24); + const uint32_t *prev = (const uint32_t *)&out->round[round - 1]; + uint32_t *next = (uint32_t *)&out->round[round]; + for (uint32_t j = 0; j < 4; ++j) { + t ^= *prev; + *next = t; + ++next; + ++prev; + } + } +} + +static uint8x16_t +addRoundKey(uint8x16_t state, const struct Aes128Cbc_Key *key) +{ + uint8x16_t r = vld1q_u8(key->data); + return veorq_u8(state, r); +} + +static uint8x16_t +invShiftRowsSubBytes(uint8x16_t state) +{ + uint8_t s[16]; + uint8_t o[16]; + + // 0 1 2 3 4 5 6 7 8 9 A B C D E F + // | | | | + // 0 D A 7 4 1 E B 8 5 2 F C 9 6 3 + vst1q_u8(s, state); + o[0] = RSBOX[s[0]]; + o[1] = RSBOX[s[13]]; + o[2] = RSBOX[s[10]]; + o[3] = RSBOX[s[7]]; + o[4] = RSBOX[s[4]]; + o[5] = RSBOX[s[1]]; + o[6] = RSBOX[s[14]]; + o[7] = RSBOX[s[11]]; + o[8] = RSBOX[s[8]]; + o[9] = RSBOX[s[5]]; + o[10] = RSBOX[s[2]]; + o[11] = RSBOX[s[15]]; + o[12] = RSBOX[s[12]]; + o[13] = RSBOX[s[9]]; + o[14] = RSBOX[s[6]]; + o[15] = RSBOX[s[3]]; + return vld1q_u8(o); +} + +static uint8x16_t +invMixColumns(uint8x16_t state) +{ + uint8_t s[16]; + uint32_t v[4]; + + vst1q_u8(s, state); + v[0] = MULTIPLY_0[s[0]]; + v[1] = MULTIPLY_0[s[4]]; + v[2] = MULTIPLY_0[s[8]]; + v[3] = MULTIPLY_0[s[12]]; + uint8x16_t a = vld1q_u8((uint8_t *)v); + + v[0] = MULTIPLY_1[s[1]]; + v[1] = MULTIPLY_1[s[5]]; + v[2] = MULTIPLY_1[s[9]]; + v[3] = MULTIPLY_1[s[13]]; + uint8x16_t b = vld1q_u8((uint8_t *)v); + + v[0] = MULTIPLY_2[s[2]]; + v[1] = MULTIPLY_2[s[6]]; + v[2] = MULTIPLY_2[s[10]]; + v[3] = MULTIPLY_2[s[14]]; + uint8x16_t c = vld1q_u8((uint8_t *)v); + + v[0] = MULTIPLY_3[s[3]]; + v[1] = MULTIPLY_3[s[7]]; + v[2] = MULTIPLY_3[s[11]]; + v[3] = MULTIPLY_3[s[15]]; + uint8x16_t d = vld1q_u8((uint8_t *)v); + + return veorq_u8(veorq_u8(a, b), veorq_u8(c, d)); +} + +static void +postKeyExpansion(struct Aes128Cbc_RoundKey *roundKey) +{ + for (int k = 1; k < 10; ++k) { + struct Aes128Cbc_Key *key = &roundKey->round[k]; + uint8x16_t o = vld1q_u8(key->data); + vst1q_u8(key->data, invMixColumns(o)); + } +} + +void +Aes128Cbc_init(struct Aes128Cbc *ctx, const struct Aes128Cbc_Key *key, + const struct Aes128Cbc_Iv *iv) +{ + keyExpansion(key, &ctx->roundKey); + postKeyExpansion(&ctx->roundKey); + ctx->iv = *iv; +} + +static uint8x16_t +eqInvCipher(uint8x16_t state, const struct Aes128Cbc_RoundKey *roundKey) +{ + const struct Aes128Cbc_Key *key = &roundKey->round[10]; + uint8x16_t newState = addRoundKey(state, key); + for (uint32_t round = 9; round > 0; --round) { + --key; + newState = invShiftRowsSubBytes(newState); + newState = invMixColumns(newState); + newState = addRoundKey(newState, key); + } + --key; + newState = invShiftRowsSubBytes(newState); + return addRoundKey(newState, key); +} + +void Aes128Cbc_decrypt(struct Aes128Cbc *ctx, const void *data, + size_t length, void *output) +{ + const uint8_t *in = (const uint8_t *)data; + uint8_t *out = (uint8_t *)output; + uint8x16_t iv128 = vld1q_u8(ctx->iv.data); + while (length > 0) { + uint8x16_t in128 = vld1q_u8(in); + uint8x16_t state = eqInvCipher(in128, &ctx->roundKey); + vst1q_u8(out, veorq_u8(state, iv128)); + iv128 = in128; + in += 16; + out += 16; + length -= 16; + } + vst1q_u8(ctx->iv.data, iv128); +} diff --git a/testsuite/CMakeLists.txt b/testsuite/CMakeLists.txt index 8f15462..2e0b88f 100644 --- a/testsuite/CMakeLists.txt +++ b/testsuite/CMakeLists.txt @@ -9,7 +9,19 @@ enable_testing() add_executable(testsuite) -if("${CMAKE_SYSTEM_NAME}" STREQUAL "Darwin" +if("${CMAKE_SYSTEM_NAME}" STREQUAL "Android") + if("${CMAKE_ANDROID_ARCH_ABI}" STREQUAL "arm64-v8a") + set(SOURCES arm_v7_main.cxx) + elseif("${CMAKE_ANDROID_ARCH_ABI}" STREQUAL "armeabi-v7a") + set(OPTIONS -mfpu=neon) + set(SOURCES arm_v7_main.cxx) + elseif("${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "x86_64") + set(OPTIONS -msse3 -maes) + set(SOURCES x86_64_main.cxx) + else() + set(SOURCES main.cxx) + endif() +elseif("${CMAKE_SYSTEM_NAME}" STREQUAL "Darwin" AND "${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "arm64") set(SOURCES aarch64_main.cxx) elseif("${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "x86_64") @@ -31,4 +43,4 @@ target_include_directories(testsuite PRIVATE target_link_libraries(testsuite mimicssl-aes128-cbc-decrypt) include(GoogleTest) -gtest_discover_tests(testsuite) +gtest_discover_tests(testsuite DISCOVERY_TIMEOUT 50) diff --git a/testsuite/arm_v7_main.cxx b/testsuite/arm_v7_main.cxx new file mode 100644 index 0000000..0478ef6 --- /dev/null +++ b/testsuite/arm_v7_main.cxx @@ -0,0 +1,359 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "expect.hxx" + +class Driver final { +private: + std::map> map; + bool list = false; + std::optional name; + +public: + Driver(const char* const* args) { + for (auto* a = args; *a != nullptr; ++a) { + if (std::strcmp(*a, "--gtest_list_tests") == 0) { + list = true; + return; + } + } + for (auto* a = args; *a != nullptr; ++a) { + auto o = std::string {*a}; + auto prefix = std::string {"--gtest_filter=main."}; + if (o.starts_with(prefix)) { + name = std::make_optional(o.substr(prefix.length())); + return; + } + } + } + + void add(const std::string& name, const std::function& testcase) { + map[name] = testcase; + } + + int run() const { + if (list) { + std::cout << "main.\n"; + for (auto i = map.cbegin(); i != map.cend(); ++i) { + auto [name, testcase] = *i; + std::cout << " " << name << "\n"; + } + return 0; + } + if (name.has_value()) { + std::string n = name.value(); + auto testcase = map.at(n); + try { + testcase(); + std::cout << n << ": succeeded\n"; + } catch (std::runtime_error& e) { + std::cout << n << ": failed\n"; + std::cout << " " << e.what(); + return 1; + } + return 0; + } + auto count = 0; + for (auto i = map.cbegin(); i != map.cend(); ++i) { + auto [name, testcase] = *i; + try { + testcase(); + std::cout << name << ": succeeded\n"; + } catch (std::runtime_error& e) { + std::cout << name << ": failed\n"; + std::cout << " " << e.what(); + ++count; + } + } + if (count == 0) { + std::cout << "all tests passed\n"; + return 0; + } + std::cout << count << " test(s) failed.\n"; + return 1; + } +}; + +namespace maroontress::lighter { + template <> std::string toString(std::uint8_t b) { + auto v = (uint32_t)b; + std::ostringstream out; + out << std::dec << v << " (0x" << std::hex << v << ")"; + return out.str(); + } +} + +static auto +toArray(const std::string& m) -> std::array +{ + std::uint8_t array[16]; + + if (m.length() != 32) { + throw std::runtime_error("invalid length"); + } + for (auto k = 0; k < 16; ++k) { + auto p = m.substr(k * 2, 2); + array[k] = (std::uint8_t)std::stoull(p, nullptr, 16); + } + return std::to_array(array); +} + +#include "Aes128Cbc.h" +#include "arm_v7_Aes128Cbc.c" +#include "evp.h" + +static auto +toKey(const std::string& m) -> Aes128Cbc_Key +{ + Aes128Cbc_Key key; + + auto array = toArray(m); + uint32_t *v = (uint32_t *)array.data(); + uint32_t *o = (uint32_t *)key.data; + o[0] = v[0]; + o[1] = v[1]; + o[2] = v[2]; + o[3] = v[3]; + return key; +} + +static auto +toState(const std::string& m) -> uint8x16_t +{ + auto array = toArray(m); + auto state = vld1q_u8(array.data()); + return state; +} + +static const bool dumpEnabled = false; + +static void +dump(const uint8_t* b) +{ + if constexpr (!dumpEnabled) { + return; + } + for (int k = 0; k < 16; ++k) { + std::printf("%02x", *b); + ++b; + } + std::printf("\n"); +} + +static void +dump(uint8x16_t state) +{ + uint8_t data[16]; + vst1q_u8(data, state); + dump(data); +} + +static void +checkKeyExpansion(std::array& key, + std::array& expectedList) +{ + struct Aes128Cbc_Key k; + std::memcpy(k.data, &key[0], 16); + struct Aes128Cbc_RoundKey out; + keyExpansion(&k, &out); + for (int i = 0; i <= 10; ++i) { + dump(out.round[i].data); + } + for (auto j = 0; j <= 10; ++j) { + auto expected = toArray(expectedList[j]); + auto actual = out.round[j].data; + for (auto k = 0; k < 16; ++k) { + expect(actual[k]) == expected[k]; + } + } +} + +template +static int +decrypt(FILE* in, FILE* out) +{ + static_assert(N % 16 == 0); + unsigned char inbuf[N]; + unsigned char outbuf[N]; + int inlen; + int outlen; + EVP_CIPHER_CTX* ctx; + // KEY "0123456789abcdeF" -> 30313233343536373839616263646546 + unsigned char key[] = "0123456789abcdeF"; + // IV "1234567887654321" -> 31323334353637383837363534333231 + unsigned char iv[] = "1234567887654321"; + + if ((ctx = EVP_CIPHER_CTX_new()) == NULL) { + return 0; + } + if (!EVP_DecryptInit_ex(ctx, EVP_aes_128_cbc(), NULL, key, iv)) { + EVP_CIPHER_CTX_free(ctx); + return 0; + } + for (;;) { + inlen = fread(inbuf, 1, sizeof(inbuf), in); + if (inlen <= 0) { + break; + } + if (!EVP_DecryptUpdate(ctx, outbuf, &outlen, inbuf, inlen)) { + EVP_CIPHER_CTX_free(ctx); + return 0; + } + if (outlen > 0) { + fwrite(outbuf, 1, outlen, out); + } + } + if (!EVP_DecryptFinal_ex(ctx, outbuf, &outlen)) { + EVP_CIPHER_CTX_free(ctx); + return 0; + } + if (outlen > 0) { + fwrite(outbuf, 1, outlen, out); + } + EVP_CIPHER_CTX_free(ctx); + return 1; +} + +static void +compare(const char* fileOne, const char* fileTwo) +{ + auto sizeOfOne = std::filesystem::file_size(fileOne); + auto sizeOfTwo = std::filesystem::file_size(fileTwo); + expect(sizeOfOne) == sizeOfTwo; + std::ifstream one {fileOne, std::ios_base::in | std::ios_base::binary}; + std::ifstream two {fileTwo, std::ios_base::in | std::ios_base::binary}; + expect(!one).isFalse(); + expect(!two).isFalse(); + for (;;) { + if (one.eof()) { + break; + } + auto charOne = one.get(); + auto charTwo = two.get(); + expect(charOne) == charTwo; + } +} + +int +main(int ac, char** av) +{ + auto driver = Driver {av}; + driver.add("endian", [] { + expect(std::endian::native) == std::endian::little; + }); + driver.add("keyExpansion (test vector)", [] { + auto key = toArray("2b7e151628aed2a6abf7158809cf4f3c"); + std::array expectedList = { + "2b7e151628aed2a6abf7158809cf4f3c", + "a0fafe1788542cb123a339392a6c7605", + "f2c295f27a96b9435935807a7359f67f", + "3d80477d4716fe3e1e237e446d7a883b", + "ef44a541a8525b7fb671253bdb0bad00", + "d4d1c6f87c839d87caf2b8bc11f915bc", + "6d88a37a110b3efddbf98641ca0093fd", + "4e54f70e5f5fc9f384a64fb24ea6dc4f", + "ead27321b58dbad2312bf5607f8d292f", + "ac7766f319fadc2128d12941575c006e", + "d014f9a8c9ee2589e13f0cc8b6630ca6"}; + checkKeyExpansion(key, expectedList); + }); + driver.add("keyExpansion", [] { + auto key = toArray("d41d8cd98f00b204e9800998ecf8427e"); + std::array expectedList = { + "d41d8cd98f00b204e9800998ecf8427e", + "94317f171b31cd13f2b1c48b1e4986f5", + "ad759965b644547644f590fd5abc1608", + "cc32a9db7a76fdad3e836d50643f7b58", + "b113c398cb653e35f5e6536591d9283d", + "9427e4195f42da2caaa489493b7da174", + "4b1576fb1457acd7bef3259e858e84ea", + "124af16c061d5dbbb8ee78253d60fccf", + "42fa7b4b44e726f0fc095ed5c169a21a", + "a0c0d933e427ffc3182ea116d947030c", + "36bb2706d29cd8c5cab279d313f57adf"}; + checkKeyExpansion(key, expectedList); + }); + driver.add("addRoundKey", [] { + auto state = toState("000102030405060708090a0b0c0d0e0f"); + auto key = toKey("d41d8cd98f00b204e9800998ecf8427e"); + auto newState = addRoundKey(state, &key); + dump(newState); + uint8_t actual[16]; + vst1q_u8(actual, newState); + auto expected = toArray("d41c8eda8b05b403e1890393e0f54c71"); + for (auto k = 0; k < 16; ++k) { + expect(actual[k]) == expected[k]; + } + }); + driver.add("invShiftRowsSubBytes", [] { + auto state = toState("d41d8cd98f00b204e9800998ecf8427e"); + auto newState = invShiftRowsSubBytes(state); + dump(newState); + uint8_t actual[16]; + vst1q_u8(actual, newState); + auto expected = toArray("19e1403073def6e2eb52f08a833a3ee5"); + for (auto k = 0; k < 16; ++k) { + expect(actual[k]) == expected[k]; + } + }); + driver.add("invMixColumns", [] { + auto state = toState("d41d8cd98f00b204e9800998ecf8427e"); + auto newState = invMixColumns(state); + dump(newState); + uint8_t actual[16]; + vst1q_u8(actual, newState); + auto expected = toArray("36094dee9485dbf3ef90e46339cac71c"); + for (auto k = 0; k < 16; ++k) { + expect(actual[k]) == expected[k]; + } + }); + driver.add("eqInvCipher (test vector)", [] { + auto state = toState("69c4e0d86a7b0430d8cdb78070b4c55a"); + auto key = toArray("000102030405060708090a0b0c0d0e0f"); + struct Aes128Cbc_Key k; + std::memcpy(k.data, &key[0], 16); + struct Aes128Cbc_RoundKey roundKey; + keyExpansion(&k, &roundKey); + postKeyExpansion(&roundKey); + auto newState = eqInvCipher(state, &roundKey); + dump(newState); + uint8_t actual[16]; + vst1q_u8(actual, newState); + auto expected = toArray("00112233445566778899aabbccddeeff"); + for (auto k = 0; k < 16; ++k) { + expect(actual[k]) == expected[k]; + } + }); + driver.add("alice (1024 bytes at a time)", [] { + auto* in = std::fopen("alice.md.encrypted", "rb"); + expect(in) != nullptr; + auto* out = std::fopen("alice.md", "wb"); + expect(out) != nullptr; + auto result = decrypt<1024>(in, out); + expect(result) == 1; + std::fclose(in); + std::fclose(out); + compare("alice.md", "alice.md.decrypted"); + }); + driver.add("alice (16 bytes at a time)", [] { + auto* in = std::fopen("alice.md.encrypted", "rb"); + expect(in) != nullptr; + auto* out = std::fopen("alice.md", "wb"); + expect(out) != nullptr; + auto result = decrypt<16>(in, out); + expect(result) == 1; + std::fclose(in); + std::fclose(out); + compare("alice.md", "alice.md.decrypted"); + }); + return driver.run(); +}