diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 64e1560..6b26864 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -13,7 +13,7 @@ jobs: sonarcloud-build-wrapper : build-wrapper-linux-x86-64 - runner: ubuntu-20.04 sonarcloud-build-wrapper : build-wrapper-linux-x86-64 - - runner: macos-latest + - runner: macos-13 sonarcloud-build-wrapper : build-wrapper-macosx-x86 runs-on: ${{ matrix.platform.runner }} steps: diff --git a/include/mav/Connection.h b/include/mav/Connection.h index 170232b..0254c6c 100644 --- a/include/mav/Connection.h +++ b/include/mav/Connection.h @@ -192,8 +192,8 @@ namespace mav { return !_underlying_network_fault && (millis() - _last_received_ms < CONNECTION_TIMEOUT); } - template - CallbackHandle addMessageCallback(const T &on_message, const E &on_error) { + CallbackHandle addMessageCallback(const std::function &on_message, + const std::function &on_error) { std::scoped_lock lock(_message_callback_mtx); CallbackHandle handle = _next_handle; _message_callbacks[handle] = FunctionCallback{on_message, on_error}; @@ -201,9 +201,32 @@ namespace mav { return handle; } - template - CallbackHandle addMessageCallback(const T &on_message) { - return addMessageCallback(on_message, nullptr); + CallbackHandle addMessageCallback(const std::function &on_message) { + return addMessageCallback(on_message, std::function{}); + } + + CallbackHandle addMessageCallback(const std::function &selector, + const std::function &on_message, + const std::function &on_error) { + return addMessageCallback([selector, on_message](const Message &message) { + if (selector(message)) { + on_message(message); + } + }, on_error); + } + + CallbackHandle addMessageCallback(int message_id, const std::function &on_message, + int source_id=mav::ANY_ID, int component_id=mav::ANY_ID) { + return addMessageCallback([message_id, source_id, component_id](const Message &message) { + return message.id() == message_id && + (source_id == mav::ANY_ID || message.header().systemId() == source_id) && + (component_id == mav::ANY_ID || message.header().componentId() == component_id); + }, on_message, std::function{}); + } + + CallbackHandle addMessageCallback(const std::string &message_name, const std::function &on_message, + int source_id=mav::ANY_ID, int component_id=mav::ANY_ID) { + return addMessageCallback(_message_set.idForMessage(message_name), on_message, source_id, component_id); } void removeMessageCallback(CallbackHandle handle) { diff --git a/tests/Network.cpp b/tests/Network.cpp index 381f43d..7db8428 100644 --- a/tests/Network.cpp +++ b/tests/Network.cpp @@ -94,7 +94,7 @@ uint64_t getTimestamp() { return 770479200; } -TEST_CASE("Create network runtime") { +TEST_CASE("Network runtime") { MessageSet message_set; message_set.addFromXMLString(R"( @@ -155,6 +155,24 @@ TEST_CASE("Create network runtime") { CHECK_EQ(message.get("text"), "Hello World!"); } + SUBCASE("Selects correct message for specific message id, system id, component id") { + interface.reset(); + auto expectation = connection->expect("TEST_MESSAGE", 1, 1); + // message with wrong system id + interface.addToReceiveQueue("\xfd\x10\x00\x00\x01\x02\x01\xbc\x26\x00\x2a\x00\x00\x00\x48\x65\x6c\x6c\x6f\x20\x57\x6f\x72\x6c\x64\x21\xa0\xcb"s, interface_partner); + // message with wrong component id + interface.addToReceiveQueue("\xfd\x10\x00\x00\x01\x01\x02\xbc\x26\x00\x2a\x00\x00\x00\x48\x65\x6c\x6c\x6f\x20\x57\x6f\x72\x6c\x64\x21\xe2\x61"s, interface_partner); + // message with wrong message id + interface.addToReceiveQueue("\xfd\x09\x00\x00\x00\xfd\x01\x00\x00\x00\x04\x00\x00\x00\x01\x02\x03\x05\x06\x77\x53"s, interface_partner); + // message with correct system id and component id and message id + interface.addToReceiveQueue("\xfd\x10\x00\x00\x01\x01\x01\xbc\x26\x00\x2a\x00\x00\x00\x48\x65\x6c\x6c\x6f\x20\x57\x6f\x72\x6c\x64\x21\x56\x38"s, interface_partner); + auto message = connection->receive(expectation); + // we should only have received the last message + CHECK_EQ(message.name(), "TEST_MESSAGE"); + CHECK_EQ(message.header().systemId(), 1); + CHECK_EQ(message.header().componentId(), 1); + } + SUBCASE("Message sent twice before receive") { interface.reset(); @@ -308,4 +326,58 @@ TEST_CASE("Create network runtime") { connection->receive("HEARTBEAT"); CHECK_EQ(connection->callbackCount(), 0); } + + SUBCASE("Message callback for specific message is called when message arrives") { + interface.reset(); + std::promise callback_called_promise; + auto callback_called_future = callback_called_promise.get_future(); + + connection->addMessageCallback("TEST_MESSAGE", [&callback_called_promise](const Message &message) { + if (message.name() == "TEST_MESSAGE") { + callback_called_promise.set_value(); + } + }); + + interface.addToReceiveQueue("\xfd\x10\x00\x00\x01\x61\x61\xbc\x26\x00\x2a\x00\x00\x00\x48\x65\x6c\x6c\x6f\x20\x57\x6f\x72\x6c\x64\x21\x53\xd9"s, interface_partner); + CHECK((callback_called_future.wait_for(std::chrono::seconds(2)) != std::future_status::timeout)); + connection->removeAllCallbacks(); + } + + SUBCASE("Can specify message callbacks for message id, source system id and source component id") { + interface.reset(); + + // these should not get called + connection->addMessageCallback(9916, [](const Message &message) { + FAIL("This callback should not be called"); + }, 1, 2); + + connection->addMessageCallback(9916, [](const Message &message) { + FAIL("This callback should not be called"); + }, 2, 1); + + connection->addMessageCallback(9917, [](const Message &message) { + FAIL("This callback should not be called"); + }, 1, 1); + + // run a send-receive twice - if we succeed the second time around, we know for sure that the FAIL were not + // called from the first time around, since there is only a single receive thread. + for (int i=0; i<2; i++) { + std::promise callback_called_promise; + auto callback_called_future = callback_called_promise.get_future(); + + + // this should get called + auto cb = connection->addMessageCallback(9916, [&callback_called_promise](const Message &message) { + if (message.name() == "TEST_MESSAGE") { + callback_called_promise.set_value(); + } + }, 1, 1); + + // message id is 9916, source system id is 1, source component id is 1 + interface.addToReceiveQueue("\xfd\x10\x00\x00\x01\x01\x01\xbc\x26\x00\x2a\x00\x00\x00\x48\x65\x6c\x6c\x6f\x20\x57\x6f\x72\x6c\x64\x21\x56\x38"s, interface_partner); + CHECK((callback_called_future.wait_for(std::chrono::seconds(2)) != std::future_status::timeout)); + connection->removeMessageCallback(cb); + } + connection->removeAllCallbacks(); + } }