From 174159512d46795ea09dda5101da99b0ca193edc Mon Sep 17 00:00:00 2001 From: David Moreno Montero Date: Sat, 12 Oct 2024 13:20:11 +0200 Subject: [PATCH] Reworked disconnect peers, now a specific function And used in: * Control router.disconnect command * When peer is removed * Tests --- src/control_socket.cpp | 10 +++++++ src/midipeer.hpp | 44 +++++++++++++++++++++++++++++++ src/midirouter.cpp | 55 ++++++++++++++++++++++++++++++++------- src/midirouter.hpp | 1 + tests/test_midirouter.cpp | 51 ++++++++++++++++++++++++++++++++++++ 5 files changed, 151 insertions(+), 10 deletions(-) diff --git a/src/control_socket.cpp b/src/control_socket.cpp index 4554820..9a7317b 100644 --- a/src/control_socket.cpp +++ b/src/control_socket.cpp @@ -214,6 +214,16 @@ const std::vector COMMANDS{ control.router->connect(from_peer_id, to_peer_id); return "ok"; }}, + {"router.disconnect", + "Disconnects two peers at the router. Unidirectional connection.", + [](control_socket_t &control, const json_t ¶ms) { + DEBUG("Params {}", params.dump()); + peer_id_t from_peer_id = params["from"]; + peer_id_t to_peer_id = params["to"]; + DEBUG("Disconnect peers: {} -> {}", from_peer_id, to_peer_id); + control.router->disconnect(from_peer_id, to_peer_id); + return "ok"; + }}, {"connect", "Connect to a peer send params: [hostname] | [hostname, port] | [name, " "hostname, port] | {\"name\": name, \"hostname\": hostname, \"port\": " diff --git a/src/midipeer.hpp b/src/midipeer.hpp index 2904267..8c4bed6 100644 --- a/src/midipeer.hpp +++ b/src/midipeer.hpp @@ -19,6 +19,7 @@ #pragma once #include "json_fwd.hpp" +#include "rtpmidid/logger.hpp" #include "rtpmidid/utils.hpp" #include @@ -41,15 +42,58 @@ class midipeer_t : public std::enable_shared_from_this { public: std::shared_ptr router; midipeer_id_t peer_id = 0; + /// @brief statistics int packets_sent = 0; + /// @brief statistics int packets_recv = 0; midipeer_t() = default; virtual ~midipeer_t(); + /** + * @brief Returns the status of the + * + * Basic data can be get with utils::peer_status + * + * @return json_t + */ virtual json_t status() = 0; + /** + * @brief Send a midi message to the peer + * + * @param from The peer that sends the message + * @param data The midi message + */ virtual void send_midi(midipeer_id_t from, const mididata_t &) = 0; + /** + * @brief Called when the peer is connected + * + * Normally do nothing, but might need to open a file and close + * when all disconnect signas are received + */ + virtual void connected(midipeer_id_t to) { + DEBUG("Peer connected peer_id={} to={}", peer_id, to); + }; + /** + * @brief Called when the peer is disconnected + * + * Normally do nothing, but might need to open a file and close + * when all disconnect signas are received + */ + virtual void disconnected(midipeer_id_t from) { + DEBUG("Peer disconnected peer_id={} from={}", peer_id, from); + }; + /** + * @brief Command as sent by the control interface + * + * @param cmd The command + * @param data The data + * @return json_t The response + */ virtual json_t command(const std::string &cmd, const json_t &data); + /** + * @brief Get the type of the peer + */ virtual const char *get_type() const = 0; }; } // namespace rtpmididns diff --git a/src/midirouter.cpp b/src/midirouter.cpp index 59c70aa..6f0f962 100644 --- a/src/midirouter.cpp +++ b/src/midirouter.cpp @@ -49,7 +49,7 @@ uint32_t midirouter_t::add_peer(std::shared_ptr peer) { peer, {}, }; - INFO("Added peer {}", peer_id); + INFO("Added peer type={} peer_id={}", peer->get_type(), peer_id); return peer_id; } @@ -86,13 +86,22 @@ void midirouter_t::peer_connection_loop( } void midirouter_t::remove_peer(peer_id_t peer_id) { - auto removed = peers.erase(peer_id); + INFO("Remove peer_id={}", peer_id); + auto toremove = get_peer_by_id(peer_id); + + // Find all the peers that are connected to this peer and disconnect them for (auto &peer : peers) { - auto &send_to = peer.second.send_to; - auto I = std::find(send_to.begin(), send_to.end(), peer_id); - if (I != send_to.end()) - send_to.erase(I); + // need to copy the send_to vector to avoid iterator invalidation + auto send_to_copy = peer.second.send_to; + for (auto send_to_id : send_to_copy) { + if (send_to_id == peer_id) { + disconnect(peer.first, peer_id); + } + disconnect(peer_id, peer.first); + } } + + auto removed = peers.erase(peer_id); if (removed) INFO("Removed peer {}", peer_id); } @@ -125,14 +134,40 @@ void midirouter_t::send_midi(peer_id_t from, peer_id_t to, } void midirouter_t::connect(peer_id_t from, peer_id_t to) { - auto send_peer = get_peerdata_by_id(from); - auto recv_peer = get_peerdata_by_id(to); - if (!send_peer || !recv_peer) { + auto from_peer = get_peerdata_by_id(from); + auto to_peer = get_peerdata_by_id(to); + if (!from_peer || !to_peer) { WARNING("Sending to unkown peer {} -> {}", from, to); return; } - send_peer->send_to.push_back(to); + from_peer->send_to.push_back(to); + + from_peer->peer->connected(to); + to_peer->peer->connected(from); + + INFO("Connect {} -> {}", from, to); +} + +void midirouter_t::disconnect(peer_id_t from, peer_id_t to) { + auto from_peer = get_peerdata_by_id(from); + auto to_peer = get_peerdata_by_id(to); + if (!from_peer || !to_peer) { + WARNING("Sending to unkown peer {} -> {}", from, to); + return; + } + + for (auto send_to_id : from_peer->send_to) { + if (send_to_id == to) { + from_peer->send_to.erase( + std::remove(from_peer->send_to.begin(), from_peer->send_to.end(), to), + from_peer->send_to.end()); + from_peer->peer->disconnected(to); + to_peer->peer->disconnected(from); + } + } + + INFO("Disconnect {} -> {}", from, to); } json_t midirouter_t::status() { diff --git a/src/midirouter.hpp b/src/midirouter.hpp index c0c9783..8e28433 100644 --- a/src/midirouter.hpp +++ b/src/midirouter.hpp @@ -53,6 +53,7 @@ class midirouter_t : public std::enable_shared_from_this { void remove_peer(peer_id_t); void connect(peer_id_t from, peer_id_t to); + void disconnect(peer_id_t from, peer_id_t to); void peer_connection_loop(peer_id_t peer_id, std::function)>); diff --git a/tests/test_midirouter.cpp b/tests/test_midirouter.cpp index a7b1764..58a4ef8 100644 --- a/tests/test_midirouter.cpp +++ b/tests/test_midirouter.cpp @@ -190,11 +190,62 @@ void test_midirouter_for_each_peer() { ASSERT_EQUAL(count, 1); } +class test_signal_t : public rtpmididns::midipeer_t { +public: + int connections = 0; + void send_midi(rtpmididns::midipeer_id_t from, + const rtpmididns::mididata_t &) override {} + const char *get_type() const override { return "test_signal_t"; } + rtpmididns::json_t status() override { return rtpmididns::json_t{}; } + void connected(rtpmididns::midipeer_id_t to) override { + connections++; + rtpmididns::midipeer_t::connected(to); + } + void disconnected(rtpmididns::midipeer_id_t from) override { + connections--; + rtpmididns::midipeer_t::disconnected(from); + } +}; + +void test_connect_disconnect_signals() { + auto router = std::make_shared(); + auto peera = std::make_shared(); + auto peerb = std::make_shared(); + + router->add_peer(peera); + router->add_peer(peerb); + + DEBUG("One connection"); + router->connect(peera->peer_id, peerb->peer_id); + ASSERT_EQUAL(peera->connections, 1); + ASSERT_EQUAL(peerb->connections, 1); + ASSERT_EQUAL(router->peers[peera->peer_id].send_to.size(), 1); + ASSERT_EQUAL(router->peers[peera->peer_id].send_to[0], peerb->peer_id); + ASSERT_EQUAL(router->peers[peerb->peer_id].send_to.size(), 0); + + DEBUG("Remove one connection"); + router->disconnect(peera->peer_id, peerb->peer_id); + ASSERT_EQUAL(peera->connections, 0); + ASSERT_EQUAL(peerb->connections, 0); + + DEBUG("Bi-directional connections"); + router->connect(peera->peer_id, peerb->peer_id); + router->connect(peerb->peer_id, peera->peer_id); + ASSERT_EQUAL(peera->connections, 2); + ASSERT_EQUAL(peerb->connections, 2); + + DEBUG("Remove the peer, removes all connections"); + router->remove_peer(peera->peer_id); + ASSERT_EQUAL(peera->connections, 0); + ASSERT_EQUAL(peerb->connections, 0); +} + int main(int argc, char **argv) { test_case_t testcase{ TEST(test_basic_midirouter), TEST(test_midirouter_from_alsa), TEST(test_midirouter_for_each_peer), + TEST(test_connect_disconnect_signals), }; testcase.run(argc, argv);