diff --git a/README.md b/README.md index 35f1025..916495b 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,22 @@ Since the library is header only, you only need the library on the build system. You can also include the library as a submodule in your project. +### Running the tests + +Libmav uses [doctest](https://github.com/doctest/doctest/) and [gcovr](https://github.com/gcovr/gcovr/). + +To run the tests, build the library, then run the test executable. Test results will be output to console. + +```bash +mkdir build && cd build && cmake .. && make tests +./tests/tests +``` + +To test coverage, simple invoke the coverage tool from the root directory. +```bash +gcovr +``` + ## Getting started ### Loading a message set diff --git a/gcovr.cfg b/gcovr.cfg index 1440025..4e242a4 100644 --- a/gcovr.cfg +++ b/gcovr.cfg @@ -1,3 +1,4 @@ exclude-throw-branches = yes filter = include/mav -exclude = include/mav/rapidxml/* \ No newline at end of file +exclude = include/mav/rapidxml/* +exclude = include/mav/picosha2/* \ No newline at end of file diff --git a/include/mav/Message.h b/include/mav/Message.h index 1ae0321..bbeaba0 100644 --- a/include/mav/Message.h +++ b/include/mav/Message.h @@ -41,6 +41,7 @@ #include #include "MessageDefinition.h" #include "utils.h" +#include "picosha2/picosha2.h" namespace mav { @@ -153,6 +154,29 @@ namespace mav { throw std::runtime_error("Unknown base type"); // should never happen } + uint64_t _computeSignatureHash48(const std::array& key) const { + // signature = sha256_48(secret_key + header + payload + CRC + link-ID + timestamp) + picosha2::hash256_one_by_one hasher; + // secret_key + hasher.process(key.begin(), key.begin() + MessageDefinition::KEY_SIZE); + // header + payload + CRC + hasher.process(_backing_memory.begin(), _backing_memory.begin() + + MessageDefinition::HEADER_SIZE + header().len() + MessageDefinition::CHECKSUM_SIZE); + // link-ID + const uint8_t linkId = signature().linkId(); + hasher.process(&linkId, &linkId + MessageDefinition::SIGNATURE_LINK_ID_SIZE); + // timestamp + const uint64_t timestamp = signature().timestamp(); + std::array timestampSerialized; + serialize(timestamp, timestampSerialized.begin()); + hasher.process(timestampSerialized.begin(), timestampSerialized.begin() + MessageDefinition::SIGNATURE_TIMESTAMP_SIZE); + + hasher.finish(); + std::vector hash(picosha2::k_digest_size); + hasher.get_hash_bytes(hash.begin(), hash.end()); + return deserialize(hash.data(), MessageDefinition::SIGNATURE_SIGNATURE_SIZE); + } + public: static inline Message _instantiateFromMemory(const MessageDefinition &definition, ConnectionPartner source_partner, @@ -222,6 +246,20 @@ namespace mav { return Header(_backing_memory.data()); } + [[nodiscard]] const Signature signature() const { + if (!isFinalized()) { + throw std::runtime_error("Unable to parse unfinalized message."); + } + return Signature(&_backing_memory[MessageDefinition::HEADER_SIZE + header().len() + MessageDefinition::CHECKSUM_SIZE]); + } + + [[nodiscard]] Signature signature() { + if (!isFinalized()) { + throw std::runtime_error("Unable to parse unfinalized message."); + } + return Signature(&_backing_memory[MessageDefinition::HEADER_SIZE + header().len() + MessageDefinition::CHECKSUM_SIZE]); + } + [[nodiscard]] const ConnectionPartner& source() const { return _source_partner; } @@ -440,11 +478,23 @@ namespace mav { return ss.str(); } + [[nodiscard]] bool validate(const std::array& key) const { + return signature().signature() == _computeSignatureHash48(key); + } + [[nodiscard]] uint32_t finalize(uint8_t seq, const Identifier &sender) { + static const std::array null_key = {}; + return finalize(seq, sender, null_key, 0, 0); + } + + [[nodiscard]] uint32_t finalize(uint8_t seq, const Identifier &sender, + const std::array& key, + const uint64_t& timestamp, const uint8_t linkId = 0) { if (isFinalized()) { _unFinalize(); } + bool sign = (timestamp > 0); auto last_nonzero = std::find_if(_backing_memory.rend() - MessageDefinition::HEADER_SIZE - _message_definition->maxPayloadSize(), _backing_memory.rend(), [](const auto &v) { @@ -457,7 +507,7 @@ namespace mav { header().magic() = 0xFD; header().len() = static_cast(payload_size); - header().incompatFlags() = 0; + header().incompatFlags() = sign ? 0x01 : 0x00; header().compatFlags() = 0; header().seq() = seq; if (header().systemId() == 0) { @@ -475,7 +525,15 @@ namespace mav { _crc_offset = MessageDefinition::HEADER_SIZE + payload_size; serialize(crc.crc16(), _backing_memory.data() + _crc_offset); - return MessageDefinition::HEADER_SIZE + payload_size + MessageDefinition::CHECKSUM_SIZE; + int signature_size = 0; + if (sign) { + signature().linkId() = linkId; + signature().timestamp() = timestamp; + signature().signature() = _computeSignatureHash48(key); + signature_size = MessageDefinition::SIGNATURE_SIZE; + } + + return MessageDefinition::HEADER_SIZE + payload_size + MessageDefinition::CHECKSUM_SIZE + signature_size; } [[nodiscard]] const uint8_t* data() const { diff --git a/include/mav/MessageDefinition.h b/include/mav/MessageDefinition.h index 07e7eb8..81396a8 100644 --- a/include/mav/MessageDefinition.h +++ b/include/mav/MessageDefinition.h @@ -179,6 +179,10 @@ namespace mav { return _backing_memory[2]; } + [[nodiscard]] inline bool isSigned() const { + return incompatFlags() & 0x01; + } + inline uint8_t& compatFlags() { return _backing_memory[3]; } @@ -224,6 +228,61 @@ namespace mav { } }; + template + class Signature { + private: + BackingMemoryPointerType _backing_memory; + + // Both timestamp and signature use 6-byte fields + class _SignatureField { + private: + BackingMemoryPointerType _ptr; + public: + explicit _SignatureField(BackingMemoryPointerType ptr) : _ptr(ptr) {} + + operator uint64_t() const { + return static_cast((*static_cast(static_cast(_ptr))) & 0xFFFFFFFFFFFF); + } + + _SignatureField& operator=(uint64_t v) { + _ptr[0] = static_cast(v & 0xFF); + _ptr[1] = static_cast((v >> 8) & 0xFF); + _ptr[2] = static_cast((v >> 16) & 0xFF); + _ptr[3] = static_cast((v >> 24) & 0xFF); + _ptr[4] = static_cast((v >> 32) & 0xFF); + _ptr[5] = static_cast((v >> 40) & 0xFF); + return *this; + } + }; + + public: + explicit Signature(BackingMemoryPointerType backing_memory) : _backing_memory(backing_memory) {} + + inline uint8_t& linkId(){ + return _backing_memory[0]; + } + + [[nodiscard]] inline uint8_t linkId() const { + return _backing_memory[0]; + } + + inline _SignatureField timestamp() { + return _SignatureField(_backing_memory + 1); + } + + [[nodiscard]] inline const _SignatureField timestamp() const { + return _SignatureField(_backing_memory + 1); + } + + inline _SignatureField signature() { + return _SignatureField(_backing_memory + 7); + } + + [[nodiscard]] inline const _SignatureField signature() const { + return _SignatureField(_backing_memory + 7); + } + }; + class FieldType { public: enum class BaseType { @@ -308,8 +367,12 @@ namespace mav { static constexpr int MAX_PAYLOAD_SIZE = 255; static constexpr int HEADER_SIZE = 10; static constexpr int CHECKSUM_SIZE = 2; - static constexpr int SIGNATURE_SIZE = 13; + static constexpr int SIGNATURE_LINK_ID_SIZE = 1; + static constexpr int SIGNATURE_TIMESTAMP_SIZE = 6; + static constexpr int SIGNATURE_SIGNATURE_SIZE = 6; + static constexpr int SIGNATURE_SIZE = SIGNATURE_LINK_ID_SIZE + SIGNATURE_TIMESTAMP_SIZE + SIGNATURE_SIGNATURE_SIZE; static constexpr int MAX_MESSAGE_SIZE = MAX_PAYLOAD_SIZE + HEADER_SIZE + CHECKSUM_SIZE + SIGNATURE_SIZE; + static constexpr int KEY_SIZE = 32; [[nodiscard]] inline const std::string& name() const { return _name; diff --git a/include/mav/Network.h b/include/mav/Network.h index 5988b42..112724d 100644 --- a/include/mav/Network.h +++ b/include/mav/Network.h @@ -43,6 +43,7 @@ #include #include #include +#include #include #include #include "Connection.h" @@ -117,7 +118,7 @@ namespace mav { backing_memory[0] = 0xFD; _interface.receive(backing_memory.data() + 1, MessageDefinition::HEADER_SIZE -1); Header header{backing_memory.data()}; - const bool message_is_signed = header.incompatFlags() & 0x01; + const bool message_is_signed = header.isSigned(); const int wire_length = MessageDefinition::HEADER_SIZE + header.len() + MessageDefinition::CHECKSUM_SIZE + (message_is_signed ? MessageDefinition::SIGNATURE_SIZE : 0); auto partner = _interface.receive(backing_memory.data() + MessageDefinition::HEADER_SIZE, @@ -145,6 +146,10 @@ namespace mav { } }; + static uint64_t _get_timestamp_function_default() { + const auto now = std::chrono::system_clock::now(); + return std::chrono::duration_cast(now.time_since_epoch()).count(); + } class NetworkRuntime { private: @@ -158,6 +163,9 @@ namespace mav { std::mutex _heartbeat_message_mutex; StreamParser _parser; Identifier _own_id; + bool _sign; + std::array _key; + std::function _get_timestamp_function; std::mutex _connections_mutex; std::mutex _send_mutex; std::unordered_map&)> _on_connection_lost; void _sendMessage(Message &message, const ConnectionPartner &partner) { - int wire_length = static_cast(message.finalize(_seq++, _own_id)); + int wire_length; + if (_sign) { + wire_length = static_cast(message.finalize(_seq++, _own_id, _key, _get_timestamp_function())); + } else { + wire_length = static_cast(message.finalize(_seq++, _own_id)); + } std::unique_lock lock(_send_mutex); _interface.send(message.data(), wire_length, partner); } @@ -327,6 +340,7 @@ namespace mav { std::function&)> on_connection_lost = {}) : _interface(interface), _message_set(message_set), _parser(_message_set, _interface), _own_id(own_id), + _sign(false), _get_timestamp_function(_get_timestamp_function_default), _on_connection(std::move(on_connection)), _on_connection_lost(std::move(on_connection_lost)) { _receive_thread = std::thread{ @@ -409,6 +423,17 @@ namespace mav { _sendMessage(message, {}); } + void enableMessageSigning(std::array key, + std::function timestampFunction = _get_timestamp_function_default) { + _sign = true; + _key = key; + _get_timestamp_function = timestampFunction; + } + + void disableMessageSigning() { + _sign = false; + } + void stop() { _interface.close(); _should_terminate.store(true); diff --git a/include/mav/picosha2/license.txt b/include/mav/picosha2/license.txt new file mode 100644 index 0000000..b6658bb --- /dev/null +++ b/include/mav/picosha2/license.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2017 okdshin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/include/mav/picosha2/picosha2.h b/include/mav/picosha2/picosha2.h new file mode 100644 index 0000000..a921736 --- /dev/null +++ b/include/mav/picosha2/picosha2.h @@ -0,0 +1,377 @@ +/* +The MIT License (MIT) + +Copyright (C) 2017 okdshin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ +#ifndef PICOSHA2_H +#define PICOSHA2_H +// picosha2:20140213 + +#ifndef PICOSHA2_BUFFER_SIZE_FOR_INPUT_ITERATOR +#define PICOSHA2_BUFFER_SIZE_FOR_INPUT_ITERATOR \ + 1048576 //=1024*1024: default is 1MB memory +#endif + +#include +#include +#include +#include +#include +#include +namespace picosha2 { +typedef unsigned long word_t; +typedef unsigned char byte_t; + +static const size_t k_digest_size = 32; + +namespace detail { +inline byte_t mask_8bit(byte_t x) { return x & 0xff; } + +inline word_t mask_32bit(word_t x) { return x & 0xffffffff; } + +const word_t add_constant[64] = { + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, + 0x923f82a4, 0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, + 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, + 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, + 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, + 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b, + 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, + 0x5b9cca4f, 0x682e6ff3, 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, + 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2}; + +const word_t initial_message_digest[8] = {0x6a09e667, 0xbb67ae85, 0x3c6ef372, + 0xa54ff53a, 0x510e527f, 0x9b05688c, + 0x1f83d9ab, 0x5be0cd19}; + +inline word_t ch(word_t x, word_t y, word_t z) { return (x & y) ^ ((~x) & z); } + +inline word_t maj(word_t x, word_t y, word_t z) { + return (x & y) ^ (x & z) ^ (y & z); +} + +inline word_t rotr(word_t x, std::size_t n) { + assert(n < 32); + return mask_32bit((x >> n) | (x << (32 - n))); +} + +inline word_t bsig0(word_t x) { return rotr(x, 2) ^ rotr(x, 13) ^ rotr(x, 22); } + +inline word_t bsig1(word_t x) { return rotr(x, 6) ^ rotr(x, 11) ^ rotr(x, 25); } + +inline word_t shr(word_t x, std::size_t n) { + assert(n < 32); + return x >> n; +} + +inline word_t ssig0(word_t x) { return rotr(x, 7) ^ rotr(x, 18) ^ shr(x, 3); } + +inline word_t ssig1(word_t x) { return rotr(x, 17) ^ rotr(x, 19) ^ shr(x, 10); } + +template +void hash256_block(RaIter1 message_digest, RaIter2 first, RaIter2 last) { + assert(first + 64 == last); + static_cast(last); // for avoiding unused-variable warning + word_t w[64]; + std::fill(w, w + 64, word_t(0)); + for (std::size_t i = 0; i < 16; ++i) { + w[i] = (static_cast(mask_8bit(*(first + i * 4))) << 24) | + (static_cast(mask_8bit(*(first + i * 4 + 1))) << 16) | + (static_cast(mask_8bit(*(first + i * 4 + 2))) << 8) | + (static_cast(mask_8bit(*(first + i * 4 + 3)))); + } + for (std::size_t i = 16; i < 64; ++i) { + w[i] = mask_32bit(ssig1(w[i - 2]) + w[i - 7] + ssig0(w[i - 15]) + + w[i - 16]); + } + + word_t a = *message_digest; + word_t b = *(message_digest + 1); + word_t c = *(message_digest + 2); + word_t d = *(message_digest + 3); + word_t e = *(message_digest + 4); + word_t f = *(message_digest + 5); + word_t g = *(message_digest + 6); + word_t h = *(message_digest + 7); + + for (std::size_t i = 0; i < 64; ++i) { + word_t temp1 = h + bsig1(e) + ch(e, f, g) + add_constant[i] + w[i]; + word_t temp2 = bsig0(a) + maj(a, b, c); + h = g; + g = f; + f = e; + e = mask_32bit(d + temp1); + d = c; + c = b; + b = a; + a = mask_32bit(temp1 + temp2); + } + *message_digest += a; + *(message_digest + 1) += b; + *(message_digest + 2) += c; + *(message_digest + 3) += d; + *(message_digest + 4) += e; + *(message_digest + 5) += f; + *(message_digest + 6) += g; + *(message_digest + 7) += h; + for (std::size_t i = 0; i < 8; ++i) { + *(message_digest + i) = mask_32bit(*(message_digest + i)); + } +} + +} // namespace detail + +template +void output_hex(InIter first, InIter last, std::ostream& os) { + os.setf(std::ios::hex, std::ios::basefield); + while (first != last) { + os.width(2); + os.fill('0'); + os << static_cast(*first); + ++first; + } + os.setf(std::ios::dec, std::ios::basefield); +} + +template +void bytes_to_hex_string(InIter first, InIter last, std::string& hex_str) { + std::ostringstream oss; + output_hex(first, last, oss); + hex_str.assign(oss.str()); +} + +template +void bytes_to_hex_string(const InContainer& bytes, std::string& hex_str) { + bytes_to_hex_string(bytes.begin(), bytes.end(), hex_str); +} + +template +std::string bytes_to_hex_string(InIter first, InIter last) { + std::string hex_str; + bytes_to_hex_string(first, last, hex_str); + return hex_str; +} + +template +std::string bytes_to_hex_string(const InContainer& bytes) { + std::string hex_str; + bytes_to_hex_string(bytes, hex_str); + return hex_str; +} + +class hash256_one_by_one { + public: + hash256_one_by_one() { init(); } + + void init() { + buffer_.clear(); + std::fill(data_length_digits_, data_length_digits_ + 4, word_t(0)); + std::copy(detail::initial_message_digest, + detail::initial_message_digest + 8, h_); + } + + template + void process(RaIter first, RaIter last) { + add_to_data_length(static_cast(std::distance(first, last))); + std::copy(first, last, std::back_inserter(buffer_)); + std::size_t i = 0; + for (; i + 64 <= buffer_.size(); i += 64) { + detail::hash256_block(h_, buffer_.begin() + i, + buffer_.begin() + i + 64); + } + buffer_.erase(buffer_.begin(), buffer_.begin() + i); + } + + void finish() { + byte_t temp[64]; + std::fill(temp, temp + 64, byte_t(0)); + std::size_t remains = buffer_.size(); + std::copy(buffer_.begin(), buffer_.end(), temp); + temp[remains] = 0x80; + + if (remains > 55) { + std::fill(temp + remains + 1, temp + 64, byte_t(0)); + detail::hash256_block(h_, temp, temp + 64); + std::fill(temp, temp + 64 - 4, byte_t(0)); + } else { + std::fill(temp + remains + 1, temp + 64 - 4, byte_t(0)); + } + + write_data_bit_length(&(temp[56])); + detail::hash256_block(h_, temp, temp + 64); + } + + template + void get_hash_bytes(OutIter first, OutIter last) const { + for (const word_t* iter = h_; iter != h_ + 8; ++iter) { + for (std::size_t i = 0; i < 4 && first != last; ++i) { + *(first++) = detail::mask_8bit( + static_cast((*iter >> (24 - 8 * i)))); + } + } + } + + private: + void add_to_data_length(word_t n) { + word_t carry = 0; + data_length_digits_[0] += n; + for (std::size_t i = 0; i < 4; ++i) { + data_length_digits_[i] += carry; + if (data_length_digits_[i] >= 65536u) { + carry = data_length_digits_[i] >> 16; + data_length_digits_[i] &= 65535u; + } else { + break; + } + } + } + void write_data_bit_length(byte_t* begin) { + word_t data_bit_length_digits[4]; + std::copy(data_length_digits_, data_length_digits_ + 4, + data_bit_length_digits); + + // convert byte length to bit length (multiply 8 or shift 3 times left) + word_t carry = 0; + for (std::size_t i = 0; i < 4; ++i) { + word_t before_val = data_bit_length_digits[i]; + data_bit_length_digits[i] <<= 3; + data_bit_length_digits[i] |= carry; + data_bit_length_digits[i] &= 65535u; + carry = (before_val >> (16 - 3)) & 65535u; + } + + // write data_bit_length + for (int i = 3; i >= 0; --i) { + (*begin++) = static_cast(data_bit_length_digits[i] >> 8); + (*begin++) = static_cast(data_bit_length_digits[i]); + } + } + std::vector buffer_; + word_t data_length_digits_[4]; // as 64bit integer (16bit x 4 integer) + word_t h_[8]; +}; + +inline void get_hash_hex_string(const hash256_one_by_one& hasher, + std::string& hex_str) { + byte_t hash[k_digest_size]; + hasher.get_hash_bytes(hash, hash + k_digest_size); + return bytes_to_hex_string(hash, hash + k_digest_size, hex_str); +} + +inline std::string get_hash_hex_string(const hash256_one_by_one& hasher) { + std::string hex_str; + get_hash_hex_string(hasher, hex_str); + return hex_str; +} + +namespace impl { +template +void hash256_impl(RaIter first, RaIter last, OutIter first2, OutIter last2, int, + std::random_access_iterator_tag) { + hash256_one_by_one hasher; + // hasher.init(); + hasher.process(first, last); + hasher.finish(); + hasher.get_hash_bytes(first2, last2); +} + +template +void hash256_impl(InputIter first, InputIter last, OutIter first2, + OutIter last2, int buffer_size, std::input_iterator_tag) { + std::vector buffer(buffer_size); + hash256_one_by_one hasher; + // hasher.init(); + while (first != last) { + int size = buffer_size; + for (int i = 0; i != buffer_size; ++i, ++first) { + if (first == last) { + size = i; + break; + } + buffer[i] = *first; + } + hasher.process(buffer.begin(), buffer.begin() + size); + } + hasher.finish(); + hasher.get_hash_bytes(first2, last2); +} +} + +template +void hash256(InIter first, InIter last, OutIter first2, OutIter last2, + int buffer_size = PICOSHA2_BUFFER_SIZE_FOR_INPUT_ITERATOR) { + picosha2::impl::hash256_impl( + first, last, first2, last2, buffer_size, + typename std::iterator_traits::iterator_category()); +} + +template +void hash256(InIter first, InIter last, OutContainer& dst) { + hash256(first, last, dst.begin(), dst.end()); +} + +template +void hash256(const InContainer& src, OutIter first, OutIter last) { + hash256(src.begin(), src.end(), first, last); +} + +template +void hash256(const InContainer& src, OutContainer& dst) { + hash256(src.begin(), src.end(), dst.begin(), dst.end()); +} + +template +void hash256_hex_string(InIter first, InIter last, std::string& hex_str) { + byte_t hashed[k_digest_size]; + hash256(first, last, hashed, hashed + k_digest_size); + std::ostringstream oss; + output_hex(hashed, hashed + k_digest_size, oss); + hex_str.assign(oss.str()); +} + +template +std::string hash256_hex_string(InIter first, InIter last) { + std::string hex_str; + hash256_hex_string(first, last, hex_str); + return hex_str; +} + +inline void hash256_hex_string(const std::string& src, std::string& hex_str) { + hash256_hex_string(src.begin(), src.end(), hex_str); +} + +template +void hash256_hex_string(const InContainer& src, std::string& hex_str) { + hash256_hex_string(src.begin(), src.end(), hex_str); +} + +template +std::string hash256_hex_string(const InContainer& src) { + return hash256_hex_string(src.begin(), src.end()); +} +templatevoid hash256(std::ifstream& f, OutIter first, OutIter last){ + hash256(std::istreambuf_iterator(f), std::istreambuf_iterator(), first,last); + +} +}// namespace picosha2 +#endif // PICOSHA2_H diff --git a/sonar-project.properties b/sonar-project.properties index 18cc5f1..4d9c71f 100644 --- a/sonar-project.properties +++ b/sonar-project.properties @@ -5,9 +5,9 @@ sonar.projectName = libmav sonar.sources = include/,tests/ -sonar.exclusions = include/mav/rapidxml/,tests/doctest.h +sonar.exclusions = include/mav/rapidxml/*,tests/doctest.h,include/mav/picosha2/* -sonar.coverage.exclusions = tests/**/*,include/mav/rapidxml/ +sonar.coverage.exclusions = tests/**/*,include/mav/rapidxml/*,include/mav/picosha2/* sonar.cpd.exclusions = tests/**/* sonar.sourceEncoding = UTF-8 diff --git a/tests/Message.cpp b/tests/Message.cpp index 3edf1be..166a9f0 100644 --- a/tests/Message.cpp +++ b/tests/Message.cpp @@ -378,5 +378,25 @@ TEST_CASE("Message set creation") { "Message ID 9915 (BIG_MESSAGE) \n char_arr_field: \"Hello World!\"\n double_field: 9\n float_arr_field: 1, 2, 3\n float_field: 10\n int16_field: -4\n int32_arr_field: 4, 5, 6\n int32_field: -6\n int64_field: 8\n int8_field: -2\n uint16_field: 3\n uint32_field: 5\n uint64_field: 7\n uint8_field: 1\n"); } + SUBCASE("Sign a packet") { + + std::array key; + for (int i = 0 ; i < 32; i++) key[i] = i; + + uint64_t timestamp = 770479200; + + // Attempt to access signature before signed (const & non-const versions) + const auto const_message = message_set.create("UINT8_ONLY_MESSAGE"); + CHECK_THROWS_AS(message.signature(), std::runtime_error); + CHECK_THROWS_AS(const_message.signature(), std::runtime_error); + + uint32_t wire_size = message.finalize(5, {6, 7}, key, timestamp); + + CHECK_EQ(wire_size, 26); + CHECK(message.header().isSigned()); + CHECK_NE(message.signature().signature(), 0); + CHECK_EQ(message.signature().timestamp(), timestamp); + CHECK(message.validate(key)); + } } diff --git a/tests/Network.cpp b/tests/Network.cpp index c0ad348..18b4e03 100644 --- a/tests/Network.cpp +++ b/tests/Network.cpp @@ -90,6 +90,9 @@ class DummyInterface : public NetworkInterface { } }; +uint64_t getTimestamp() { + return 770479200; +} TEST_CASE("Create network runtime") { @@ -122,6 +125,9 @@ TEST_CASE("Create network runtime") { DummyInterface interface; NetworkRuntime network({253, 1}, message_set, interface); + std::array key; + for (int i = 0 ; i < 32; i++) key[i] = i; + // send a heartbeat message, to establish a connection interface.addToReceiveQueue("\xfd\x09\x00\x00\x00\xfd\x01\x00\x00\x00\x04\x00\x00\x00\x01\x02\x03\x05\x06\x77\x53"s, interface_partner); auto connection = network.awaitConnection(); @@ -213,4 +219,50 @@ TEST_CASE("Create network runtime") { CHECK_EQ(message.name(), "HEARTBEAT"); CHECK(connection->alive()); } + + SUBCASE("Enable message signing") { + auto message = message_set.create("TEST_MESSAGE")({ + {"value", 42}, + {"text", "Hello World!"} + }); + interface.reset(); + network.enableMessageSigning(key); + connection->send(message); + CHECK(message.header().isSigned()); + // don't check anything after link_id in signature as the timestamp is dependent on current time + bool found = (interface.sendSpongeContains( + "\xfd\x10\x01\x00\x00\xfd\x01\xbc\x26\x00\x2a\x00\x00\x00\x48\x65\x6c\x6c\x6f\x20\x57\x6f\x72\x6c\x64\x21\xfd\x33\x00"s, + interface_partner)); + CHECK(found); + } + + SUBCASE("Enable message signing with custom timestamp function") { + auto message = message_set.create("TEST_MESSAGE")({ + {"value", 42}, + {"text", "Hello World!"} + }); + interface.reset(); + network.enableMessageSigning(key, getTimestamp); + connection->send(message); + CHECK(message.header().isSigned()); + bool found = (interface.sendSpongeContains( + "\xfd\x10\x01\x00\x00\xfd\x01\xbc\x26\x00\x2a\x00\x00\x00\x48\x65\x6c\x6c\x6f\x20\x57\x6f\x72\x6c\x64\x21\xfd\x33\x00\x60\x94\xec\x2d\x00\x00\x7b\xab\xfa\x1a\xed\xf9"s, + interface_partner)); + CHECK(found); + } + + SUBCASE("Disable message signing") { + auto message = message_set.create("TEST_MESSAGE")({ + {"value", 42}, + {"text", "Hello World!"} + }); + interface.reset(); + network.disableMessageSigning(); + connection->send(message); + CHECK(!message.header().isSigned()); + bool found = (interface.sendSpongeContains( + "\xfd\x10\x00\x00\x00\xfd\x01\xbc\x26\x00\x2a\x00\x00\x00\x48\x65\x6c\x6c\x6f\x20\x57\x6f\x72\x6c\x64\x21\x86\x37"s, + interface_partner)); + CHECK(found); + } }