Skip to content

Commit

Permalink
RTC-15334: Fix crash caused by a race condition
Browse files Browse the repository at this point in the history
  • Loading branch information
RicardoMDomingues committed Jan 9, 2025
1 parent b6af4db commit 355922d
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 24 deletions.
2 changes: 1 addition & 1 deletion bridge/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ void Bridge::initialize(std::shared_ptr<transport::EndpointFactory> endpointFact
return;
}

_probeServer = std::make_unique<transport::ProbeServer>(_iceConfig, _config);
_probeServer = std::make_unique<transport::ProbeServer>(_iceConfig, _config, *_rtJobManager);

const auto credentials = _probeServer->getCredentials();

Expand Down
70 changes: 53 additions & 17 deletions transport/ProbeServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@
#include "transport/ice/IceSession.h"
#include "transport/ice/Stun.h"
#include "utils/ContainerAlgorithms.h"
#include "utils/Function.h"
#include "utils/Time.h"
#include <unistd.h>

namespace transport
{
ProbeServer::ProbeServer(const ice::IceConfig& iceConfig, const config::Config& config)
ProbeServer::ProbeServer(const ice::IceConfig& iceConfig,
const config::Config& config,
jobmanager::JobManager& jobmanager)
: _iceConfig(iceConfig),
_config(config),
_jobQueue(jobmanager, 1024),
_queue(1024),
_maintenanceThreadIsRunning(true),
_maintenanceThread(new std::thread([this] { this->run(); }))
Expand Down Expand Up @@ -40,7 +44,7 @@ void ProbeServer::onDtlsReceived(Endpoint& endpoint,
const SocketAddress& source,
const SocketAddress& target,
memory::UniquePacket packet,
const uint64_t timestamp){};
const uint64_t timestamp) {};

void ProbeServer::onRtcpReceived(Endpoint& endpoint,
const SocketAddress& source,
Expand All @@ -56,7 +60,12 @@ void ProbeServer::onIceReceived(Endpoint& endpoint,
memory::UniquePacket packet,
const uint64_t timestamp)
{
replyStunOk(endpoint, source, std::move(packet));
_jobQueue.post(utils::bind(&ProbeServer::onIceReceivedInternal,
this,
std::ref(endpoint),
source,
utils::moveParam(packet),
timestamp));
}

void ProbeServer::onRegistered(Endpoint& endpoint)
Expand Down Expand Up @@ -242,26 +251,24 @@ void ProbeServer::onIceTcpConnect(std::shared_ptr<Endpoint> endpoint,
memory::UniquePacket packet,
const uint64_t timestamp)
{
if (endpoint->getTransportType() == ice::TransportType::TCP)
{
replyStunOk(*endpoint, source, std::move(packet));

ProbeTcpConnection connection;
connection.endpoint = endpoint;
connection.timestamp = utils::Time::getAbsoluteTime();
_queue.push(connection);
}
_jobQueue.post(utils::bind(&ProbeServer::onIceTcpConnectInternal,
this,
endpoint,
source,
utils::moveParam(packet),
timestamp));
}

// Endpoint::IStopEvents
void ProbeServer::onEndpointStopped(Endpoint* endpoint) {}

void ProbeServer::replyStunOk(Endpoint& endpoint, const SocketAddress& destination, memory::UniquePacket packet)
bool ProbeServer::replyStunOk(Endpoint& endpoint,
const SocketAddress& destination,
memory::UniquePacket packet,
const uint64_t timestamp)
{
uint64_t timestamp = utils::Time::getAbsoluteTime();
const void* data = packet->get();

auto* stunMessage = ice::StunMessage::fromPtr(data);
auto* stunMessage = ice::StunMessage::fromPtr(packet->get());

if (stunMessage && stunMessage->isValid() && stunMessage->header.isRequest() &&
stunMessage->isAuthentic(_hmacComputer))
Expand All @@ -274,11 +281,40 @@ void ProbeServer::replyStunOk(Endpoint& endpoint, const SocketAddress& destinati
response.addFingerprint();

endpoint.sendStunTo(destination, response.header.transactionId.get(), &response, response.size(), timestamp);
return true;
}
else

if (endpoint.getTransportType() != ice::TransportType::UDP)
{
endpoint.stop(this);
}

return false;
}

void ProbeServer::onIceReceivedInternal(Endpoint& endpoint,
const SocketAddress& source,
memory::UniquePacket packet,
uint64_t timestamp)
{
replyStunOk(endpoint, source, std::move(packet), timestamp);
}

void ProbeServer::onIceTcpConnectInternal(std::shared_ptr<Endpoint> endpoint,
const SocketAddress& source,
memory::UniquePacket packet,
const uint64_t timestamp)
{
if (endpoint->getTransportType() == ice::TransportType::TCP)
{
if (replyStunOk(*endpoint, source, std::move(packet), timestamp))
{
ProbeTcpConnection connection;
connection.endpoint = endpoint;
connection.timestamp = utils::Time::getAbsoluteTime();
_queue.push(connection);
}
}
}

void ProbeServer::addCandidate(const ice::IceCandidate& candidate)
Expand Down
25 changes: 19 additions & 6 deletions transport/ProbeServer.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "crypto/SslHelper.h"
#include "ice/IceCandidate.h"
#include "ice/Stun.h"
#include "jobmanager/JobQueue.h"
#include "transport/Endpoint.h"
#include <config/Config.h>
#include <mutex>
Expand All @@ -15,8 +16,8 @@ class ProbeServer : public Endpoint::IEvents, public ServerEndpoint::IEvents, pu
{

public:
ProbeServer(const ice::IceConfig& iceConfig, const config::Config& config);
virtual ~ProbeServer(){};
ProbeServer(const ice::IceConfig& iceConfig, const config::Config& config, jobmanager::JobManager& jobmanager);
virtual ~ProbeServer() {};

// Endpoint::IEvents
void onRtpReceived(Endpoint&,
Expand Down Expand Up @@ -69,6 +70,21 @@ class ProbeServer : public Endpoint::IEvents, public ServerEndpoint::IEvents, pu
void run();
void stop();

private:
void onIceReceivedInternal(Endpoint& endpoint,
const SocketAddress& source,
memory::UniquePacket packet,
uint64_t timestamp);

void onIceTcpConnectInternal(std::shared_ptr<Endpoint> endpoint,
const SocketAddress& source,
memory::UniquePacket packet,
const uint64_t timestamp);

bool replyStunOk(Endpoint&, const SocketAddress&, memory::UniquePacket, const uint64_t timestamp);
void addCandidate(const ice::IceCandidate& candidate);
int getInterfaceIndex(transport::SocketAddress address);

private:
std::pair<std::string, std::string> _credentials;
const ice::IceConfig& _iceConfig;
Expand All @@ -86,13 +102,10 @@ class ProbeServer : public Endpoint::IEvents, public ServerEndpoint::IEvents, pu
};

crypto::HMAC _hmacComputer;
jobmanager::JobQueue _jobQueue;
std::vector<ProbeTcpConnection> _tcpConnections;
concurrency::MpmcQueue<ProbeTcpConnection> _queue;
std::atomic_bool _maintenanceThreadIsRunning;
std::unique_ptr<std::thread> _maintenanceThread;

void replyStunOk(Endpoint&, const SocketAddress&, memory::UniquePacket);
void addCandidate(const ice::IceCandidate& candidate);
int getInterfaceIndex(transport::SocketAddress address);
};
} // namespace transport

0 comments on commit 355922d

Please sign in to comment.