diff --git a/common/Unit.hpp b/common/Unit.hpp index bd2cdfd71553..4bac605c6c79 100644 --- a/common/Unit.hpp +++ b/common/Unit.hpp @@ -503,6 +503,15 @@ class UnitWSD : public UnitBase return false; } + // ---------------- ServerSocket hooks ---------------- + /// Simulate `::accept` errors for external `ServerSocket::accept`. Implement unrecoverable errors by throwing an exception. + virtual bool simulateExternalAcceptError() + { + return false; + } + /// Simulate exceptions during `StreamSocket` constructor for external `ServerSocket::accept`. + virtual void simulateExternalSocketCtorException(std::shared_ptr& /*socket*/) { } + // ---------------- TileCache hooks ---------------- /// Called before the lookupTile call returns. Should always be called to fire events. virtual void lookupTile(int part, int mode, int width, int height, int tilePosX, int tilePosY, diff --git a/net/ServerSocket.hpp b/net/ServerSocket.hpp index 27ba9c49bff4..91098fade531 100644 --- a/net/ServerSocket.hpp +++ b/net/ServerSocket.hpp @@ -104,22 +104,16 @@ class ServerSocket : public Socket if (events & POLLIN) { std::shared_ptr clientSocket = accept(); - if (!clientSocket) + if (clientSocket) { - const std::string msg = "Failed to accept. (errno: "; - throw std::runtime_error(msg + std::strerror(errno) + ')'); - } - const size_t extConnCount = StreamSocket::getExternalConnectionCount(); - if( 0 == net::Defaults.maxExtConnections || extConnCount <= net::Defaults.maxExtConnections ) - { - LOG_TRC("Accepted client #" << clientSocket->getFD()); + LOGA_TRC(Socket, "Accepted client #" << clientSocket->getFD() << ", " << *clientSocket); _clientPoller.insertNewSocket(std::move(clientSocket)); - } else - LOG_WRN("Limiter rejected extConn[" << extConnCount << "/" << net::Defaults.maxExtConnections << "]: " << *clientSocket); + } } } protected: + bool isUnrecoverableAcceptError(const int cause); /// Create a Socket instance from the accepted socket FD. std::shared_ptr createSocketFromAccept(int fd, Socket::Type type) const { diff --git a/net/Socket.cpp b/net/Socket.cpp index 7f317139066e..4a8f4e6f5e72 100644 --- a/net/Socket.cpp +++ b/net/Socket.cpp @@ -15,6 +15,7 @@ #include "TraceEvent.hpp" #include "Util.hpp" +#include #include #include #include @@ -26,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -1149,6 +1151,41 @@ bool ServerSocket::bind([[maybe_unused]] Type type, [[maybe_unused]] int port) #endif } +bool ServerSocket::isUnrecoverableAcceptError(const int cause) +{ + constexpr const char * messagePrefix = "Failed to accept. (errno: "; + switch(cause) + { + case EINTR: + case EAGAIN: // == EWOULDBLOCK + case ENETDOWN: + case EPROTO: + case ENOPROTOOPT: + case EHOSTDOWN: +#ifdef ENONET + case ENONET: +#endif + case EHOSTUNREACH: + case EOPNOTSUPP: + case ENETUNREACH: + case ECONNABORTED: + case ETIMEDOUT: + case EMFILE: + case ENFILE: + case ENOMEM: + case ENOBUFS: + { + LOG_DBG(messagePrefix << std::to_string(cause) << ", " << std::strerror(cause) << ')'); + return false; + } + default: + { + LOG_FTL(messagePrefix << std::to_string(cause) << ", " << std::strerror(cause) << ')'); + return true; + } + } +} + std::shared_ptr ServerSocket::accept() { // Accept a connection (if any) and set it to non-blocking. @@ -1156,56 +1193,73 @@ std::shared_ptr ServerSocket::accept() #if !MOBILEAPP assert(_type != Socket::Type::Unix); + UnitWSD* const unitWsd = UnitWSD::isUnitTesting() ? &UnitWSD::get() : nullptr; + if (unitWsd && unitWsd->simulateExternalAcceptError()) + return nullptr; // Recoverable error, ignore to retry + struct sockaddr_in6 clientInfo; socklen_t addrlen = sizeof(clientInfo); const int rc = ::accept4(getFD(), (struct sockaddr *)&clientInfo, &addrlen, SOCK_NONBLOCK | SOCK_CLOEXEC); #else const int rc = fakeSocketAccept4(getFD()); #endif - LOG_TRC("Accepted socket #" << rc << ", creating socket object."); - try + if (rc < 0) { - // Create a socket object using the factory. - if (rc != -1) - { + if (isUnrecoverableAcceptError(errno)) + Util::forcedExit(EX_SOFTWARE); + return nullptr; + } + LOG_TRC("Accepted socket #" << rc << ", creating socket object."); + #if !MOBILEAPP - char addrstr[INET6_ADDRSTRLEN]; + char addrstr[INET6_ADDRSTRLEN]; - Socket::Type type; - const void *inAddr; - if (clientInfo.sin6_family == AF_INET) - { - struct sockaddr_in *ipv4 = (struct sockaddr_in *)&clientInfo; - inAddr = &(ipv4->sin_addr); - type = Socket::Type::IPv4; - } - else - { - struct sockaddr_in6 *ipv6 = &clientInfo; - inAddr = &(ipv6->sin6_addr); - type = Socket::Type::IPv6; - } + Socket::Type type; + const void *inAddr; + if (clientInfo.sin6_family == AF_INET) + { + struct sockaddr_in *ipv4 = (struct sockaddr_in *)&clientInfo; + inAddr = &(ipv4->sin_addr); + type = Socket::Type::IPv4; + } + else + { + struct sockaddr_in6 *ipv6 = &clientInfo; + inAddr = &(ipv6->sin6_addr); + type = Socket::Type::IPv6; + } + ::inet_ntop(clientInfo.sin6_family, inAddr, addrstr, sizeof(addrstr)); - std::shared_ptr _socket = createSocketFromAccept(rc, type); + const size_t extConnCount = StreamSocket::getExternalConnectionCount(); + if (net::Defaults.maxExtConnections > 0 && extConnCount >= net::Defaults.maxExtConnections) + { + LOG_WRN("Limiter rejected extConn[" << extConnCount << "/" << net::Defaults.maxExtConnections << "]: #" + << rc << " has family " + << clientInfo.sin6_family << ", address " << addrstr << ":" << clientInfo.sin6_port); + ::close(rc); + return nullptr; + } + try + { + // Create a socket object using the factory. + std::shared_ptr _socket = createSocketFromAccept(rc, type); + if (unitWsd) + unitWsd->simulateExternalSocketCtorException(_socket); - ::inet_ntop(clientInfo.sin6_family, inAddr, addrstr, sizeof(addrstr)); - _socket->setClientAddress(addrstr, clientInfo.sin6_port); + _socket->setClientAddress(addrstr, clientInfo.sin6_port); - LOG_TRC("Accepted socket #" << _socket->getFD() << " has family " - << clientInfo.sin6_family << ", " << *_socket); -#else - std::shared_ptr _socket = createSocketFromAccept(rc, Socket::Type::Unix); -#endif - return _socket; - } - return std::shared_ptr(nullptr); + LOG_TRC("Accepted socket #" << _socket->getFD() << " has family " + << clientInfo.sin6_family << ", " << *_socket); + return _socket; } catch (const std::exception& ex) { LOG_ERR("Failed to create client socket #" << rc << ". Error: " << ex.what()); } - return nullptr; +#else + return createSocketFromAccept(rc, Socket::Type::Unix); +#endif } #if !MOBILEAPP @@ -1251,11 +1305,15 @@ bool Socket::isLocal() const std::shared_ptr LocalServerSocket::accept() { const int rc = ::accept4(getFD(), nullptr, nullptr, SOCK_NONBLOCK | SOCK_CLOEXEC); + if (rc < 0) + { + if (isUnrecoverableAcceptError(errno)) + Util::forcedExit(EX_SOFTWARE); + return nullptr; + } try { LOG_DBG("Accepted prisoner socket #" << rc << ", creating socket object."); - if (rc < 0) - return std::shared_ptr(nullptr); std::shared_ptr _socket = createSocketFromAccept(rc, Socket::Type::Unix); // Sanity check this incoming socket @@ -1305,8 +1363,8 @@ std::shared_ptr LocalServerSocket::accept() catch (const std::exception& ex) { LOG_ERR("Failed to create client socket #" << rc << ". Error: " << ex.what()); - return std::shared_ptr(nullptr); } + return nullptr; } /// Returns true on success only. diff --git a/test/Makefile.am b/test/Makefile.am index 46132e2ce61d..d3cb1b3d0975 100644 --- a/test/Makefile.am +++ b/test/Makefile.am @@ -94,6 +94,8 @@ all_la_unit_tests = \ unit-timeout_inactive.la \ unit-timeout_conn.la \ unit-timeout_none.la \ + unit-serversock_accept1.la \ + unit-streamsock_ctor1.la \ unit-base.la # unit-admin.la # unit-tilecache.la # Empty test. @@ -232,6 +234,10 @@ unit_timeout_conn_la_SOURCES = UnitTimeoutConnections.cpp unit_timeout_conn_la_LIBADD = $(CPPUNIT_LIBS) unit_timeout_none_la_SOURCES = UnitTimeoutNone.cpp unit_timeout_none_la_LIBADD = $(CPPUNIT_LIBS) +unit_serversock_accept1_la_SOURCES = UnitServerSocketAcceptFailure1.cpp +unit_serversock_accept1_la_LIBADD = $(CPPUNIT_LIBS) +unit_streamsock_ctor1_la_SOURCES = UnitStreamSocketCtorFailure1.cpp +unit_streamsock_ctor1_la_LIBADD = $(CPPUNIT_LIBS) unit_prefork_la_SOURCES = UnitPrefork.cpp unit_prefork_la_LIBADD = $(CPPUNIT_LIBS) unit_storage_la_SOURCES = UnitStorage.cpp diff --git a/test/UnitServerSocketAcceptFailure1.cpp b/test/UnitServerSocketAcceptFailure1.cpp new file mode 100644 index 000000000000..3dabf092d2f0 --- /dev/null +++ b/test/UnitServerSocketAcceptFailure1.cpp @@ -0,0 +1,166 @@ +/* -*- Mode: C++; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4; fill-column: 100 -*- */ +/* + * Copyright the Collabora Online contributors. + * + * SPDX-License-Identifier: MPL-2.0 + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +#include + +#include +#include + +#include +#include + +#include +#include + +#include +#include +#include +#include + +#include "UnitTimeoutBase.hpp" + +/// Test suite class for injected ServerSocket accept failures +class UnitServerSocketAcceptFailure1 : public UnitTimeoutBase0 +{ + TestResult testHttp(); + +public: + UnitServerSocketAcceptFailure1() + : UnitTimeoutBase0("UnitServerSocketAcceptFailure1") + { + } + + void invokeWSDTest() override; + bool simulateExternalAcceptError() override; + +private: + static constexpr size_t ExternalServerSocketAcceptSimpleErrorInterval = 2; + static constexpr size_t ExternalServerSocketAcceptFatalErrorInterval = 5; + std::atomic _externalServerSocketAcceptCount = 0; +}; + +bool UnitServerSocketAcceptFailure1::simulateExternalAcceptError() +{ + const size_t acceptCount = ++_externalServerSocketAcceptCount; + if (acceptCount % ExternalServerSocketAcceptSimpleErrorInterval == 0) + { + // recoverable error like EAGAIN + LOG_DBG("Injecting recoverable accept failure: EAGAIN: " << acceptCount); + return true; + } + else if (acceptCount % ExternalServerSocketAcceptFatalErrorInterval == 0) + { + // fatal error like EFAULT + LOG_DBG("Injecting fatal accept failure: EAGAIN: " << acceptCount); + throw std::runtime_error("Injecting fatal accept failure."); + } + else + return false; +} + +inline UnitBase::TestResult UnitServerSocketAcceptFailure1::testHttp() +{ + const size_t fatal_iter = ExternalServerSocketAcceptFatalErrorInterval - ExternalServerSocketAcceptSimpleErrorInterval; + setTestname(__func__); + TST_LOG("Starting Test: " << testname + << ": ServerSocketAcceptSimpleErrorInterval " + << ExternalServerSocketAcceptSimpleErrorInterval + << ", ServerSocketAcceptFatalErrorInterval " + << ExternalServerSocketAcceptFatalErrorInterval + << ", fatal_iter (client) " << fatal_iter); + + const std::string documentURL = "/favicon.ico"; + + constexpr bool UseOwnPoller = true; + constexpr bool PollerOnClientThread = true; + std::shared_ptr socketPoller; + std::shared_ptr session; + + for(size_t iteration = 1; iteration <= fatal_iter; ++iteration) + { + const bool expectFatalError = iteration == fatal_iter; + TST_LOG("Test[" << iteration << "]: expectFatalError " << expectFatalError); + + if( UseOwnPoller ) + { + socketPoller = std::make_shared(testname); + if( PollerOnClientThread ) + socketPoller->runOnClientThread(); + else + socketPoller->startThread(); + } else + socketPoller = socketPoll(); + + session = http::Session::create(helpers::getTestServerURI()); + bool connected00 = false; + { + TST_LOG("Test[" << iteration << "] Req1: " << testname << ": `" << documentURL << "`"); + http::Request request(documentURL, http::Request::VERB_GET); + const std::shared_ptr response = + session->syncRequest(request, *socketPoller); + TST_LOG("Test[" << iteration << "] Connected: " << session->isConnected()); + TST_LOG("Test[" << iteration << "] Response1: " << response->header().toString()); + TST_LOG("Test[" << iteration << "] Response1 size: " << testname << ": `" << documentURL << "`: " << response->header().getContentLength()); + if( session->isConnected() ) { + connected00 = true; + LOK_ASSERT_EQUAL(http::StatusCode::OK, response->statusCode()); + LOK_ASSERT(http::Header::ConnectionToken::None == + response->header().getConnectionToken()); + LOK_ASSERT(0 < response->header().getContentLength()); + } else { + // connection limit hit + LOK_ASSERT_EQUAL(http::StatusCode::None, response->statusCode()); + } + } + bool connected01 = false; + { + TST_LOG("Test[" << iteration << "] SessionA " << ": connected " << session->isConnected()); + if( session->isConnected() ) + { + connected01 = true; + session->asyncShutdown(); + } + if( UseOwnPoller ) { + if( PollerOnClientThread ) + { + socketPoller->closeAllSockets(); + } else { + socketPoller->joinThread(); + } + } + } + if( expectFatalError ) + { + LOK_ASSERT(false == connected00); + LOK_ASSERT(false == connected01); + } else { + LOK_ASSERT(true == connected00); + LOK_ASSERT(true == connected01); + } + } // for iterations + TST_LOG("Ending Test: " << testname); + return TestResult::Ok; +} + +void UnitServerSocketAcceptFailure1::invokeWSDTest() +{ + UnitBase::TestResult result; + + result = testHttp(); + if (result != TestResult::Ok) + exitTest(result); + + exitTest(TestResult::Ok); +} + +UnitBase* unit_create_wsd(void) { return new UnitServerSocketAcceptFailure1(); } + +/* vim:set shiftwidth=4 softtabstop=4 expandtab: */ diff --git a/test/UnitStreamSocketCtorFailure1.cpp b/test/UnitStreamSocketCtorFailure1.cpp new file mode 100644 index 000000000000..aa79cce84ee7 --- /dev/null +++ b/test/UnitStreamSocketCtorFailure1.cpp @@ -0,0 +1,154 @@ +/* -*- Mode: C++; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4; fill-column: 100 -*- */ +/* + * Copyright the Collabora Online contributors. + * + * SPDX-License-Identifier: MPL-2.0 + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +#include + +#include +#include + +#include +#include + +#include +#include + +#include +#include +#include +#include + +#include "UnitTimeoutBase.hpp" + +/// Test suite class for injected StreamSocket ctor exceptions, handled by ServerSocket::accept +class UnitStreamSocketCtorFailure1 : public UnitTimeoutBase0 +{ + TestResult testHttp(); + +public: + UnitStreamSocketCtorFailure1() + : UnitTimeoutBase0("UnitStreamSocketCtorFailure1") + { + } + + void invokeWSDTest() override; + void simulateExternalSocketCtorException(std::shared_ptr& socket) override; + +private: + static constexpr size_t ExternalStreamSocketCtorFailureInterval = 2; + std::atomic _externalStreamSocketCount = 0; +}; + +void UnitStreamSocketCtorFailure1::simulateExternalSocketCtorException(std::shared_ptr& socket) +{ + const size_t extStreamSocketCount = ++_externalStreamSocketCount; + if (extStreamSocketCount % ExternalStreamSocketCtorFailureInterval == 0) + { + LOG_DBG("Injecting recoverable StreamSocket ctor exception " << extStreamSocketCount + << "/" << ExternalStreamSocketCtorFailureInterval << ": " << socket); + throw std::runtime_error("Test: StreamSocket exception: fd " + std::to_string(socket->getFD())); + } +} + +inline UnitBase::TestResult UnitStreamSocketCtorFailure1::testHttp() +{ + setTestname(__func__); + TST_LOG("Starting Test: " << testname + << ": StreamSocketCtorFailureInterval " + << ExternalStreamSocketCtorFailureInterval); + + const std::string documentURL = "/favicon.ico"; + + constexpr bool UseOwnPoller = true; + constexpr bool PollerOnClientThread = true; + std::shared_ptr socketPoller; + std::shared_ptr session; + const size_t iter_max = 2*ExternalStreamSocketCtorFailureInterval+1; + + for(size_t iteration = 1; iteration <= iter_max; ++iteration) + { + const bool expectFailure = 0 == iteration % ExternalStreamSocketCtorFailureInterval; + TST_LOG("Test[" << iteration << "]: expectFailure " << expectFailure); + + if( UseOwnPoller ) + { + socketPoller = std::make_shared(testname); + if( PollerOnClientThread ) + socketPoller->runOnClientThread(); + else + socketPoller->startThread(); + } else + socketPoller = socketPoll(); + + session = http::Session::create(helpers::getTestServerURI()); + bool connected00 = false; + { + TST_LOG("Test[" << iteration << "] Req1: " << testname << ": `" << documentURL << "`"); + http::Request request(documentURL, http::Request::VERB_GET); + const std::shared_ptr response = + session->syncRequest(request, *socketPoller); + TST_LOG("Test[" << iteration << "] Connected: " << session->isConnected()); + TST_LOG("Test[" << iteration << "] Response1: " << response->header().toString()); + TST_LOG("Test[" << iteration << "] Response1 size: " << testname << ": `" << documentURL << "`: " << response->header().getContentLength()); + if( session->isConnected() ) { + connected00 = true; + LOK_ASSERT_EQUAL(http::StatusCode::OK, response->statusCode()); + LOK_ASSERT(http::Header::ConnectionToken::None == + response->header().getConnectionToken()); + LOK_ASSERT(0 < response->header().getContentLength()); + } else { + // connection limit hit + LOK_ASSERT_EQUAL(http::StatusCode::None, response->statusCode()); + } + } + bool connected01 = false; + { + TST_LOG("Test[" << iteration << "] SessionA " << ": connected " << session->isConnected()); + if( session->isConnected() ) + { + connected01 = true; + session->asyncShutdown(); + } + if( UseOwnPoller ) { + if( PollerOnClientThread ) + { + socketPoller->closeAllSockets(); + } else { + socketPoller->joinThread(); + } + } + } + if( expectFailure ) + { + LOK_ASSERT(false == connected00); + LOK_ASSERT(false == connected01); + } else { + LOK_ASSERT(true == connected00); + LOK_ASSERT(true == connected01); + } + } // for iterations + TST_LOG("Ending Test: " << testname); + return TestResult::Ok; +} + +void UnitStreamSocketCtorFailure1::invokeWSDTest() +{ + UnitBase::TestResult result; + + result = testHttp(); + if (result != TestResult::Ok) + exitTest(result); + + exitTest(TestResult::Ok); +} + +UnitBase* unit_create_wsd(void) { return new UnitStreamSocketCtorFailure1(); } + +/* vim:set shiftwidth=4 softtabstop=4 expandtab: */ diff --git a/test/UnitTimeoutWSPing.cpp b/test/UnitTimeoutWSPing.cpp deleted file mode 100644 index 00c925d8851a..000000000000 --- a/test/UnitTimeoutWSPing.cpp +++ /dev/null @@ -1,95 +0,0 @@ -/* -*- Mode: C++; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4; fill-column: 100 -*- */ -/* - * Copyright the Collabora Online contributors. - * - * SPDX-License-Identifier: MPL-2.0 - * - * This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this - * file, You can obtain one at http://mozilla.org/MPL/2.0/. - */ - -#include - -#include -#include - -#include -#include - -#include -#include - -#include -#include -#include -#include - -#include "UnitTimeoutBase.hpp" - -/// Test suite class for WS Ping (native frame) timeout limit using a WS sessions. -class UnitTimeoutWSPing : public UnitTimeoutBase0 -{ - TestResult testWSPing(); - - void configure(Poco::Util::LayeredConfiguration& /* config */) override - { - net::Defaults.wsPingAvgTimeout = std::chrono::microseconds(20); - net::Defaults.wsPingInterval = std::chrono::milliseconds(10); - } - -public: - UnitTimeoutWSPing() - : UnitTimeoutBase0("UnitTimeoutWSPing") - { - } - - void invokeWSDTest() override; -}; - -/// Attempt to test the native WebSocket control-frame ping/pong facility -> Timeout! -UnitBase::TestResult UnitTimeoutWSPing::testWSPing() -{ - setTestname(__func__); - TST_LOG("Starting Test: " << testname); - - std::string documentPath, documentURL; - helpers::getDocumentPathAndURL("hello.odt", documentPath, documentURL, testname); - - // NOTE: Do not replace with wrappers. This has to be explicit. - std::shared_ptr session = http::WebSocketSession::create(helpers::getTestServerURI()); - http::Request req(documentURL); - session->asyncRequest(req, socketPoll()); - - // wsd/ClientSession.cpp:709 sendTextFrameAndLogError("error: cmd=" + tokens[0] + " kind=nodocloaded"); - constexpr const bool loadDoc = true; // Required for WSD chat -> wsd/ClientSession.cpp:709, common/Session.hpp:160 - if( loadDoc ) { - session->sendMessage("load url=" + documentURL); - } - - LOK_ASSERT_EQUAL(true, session->isConnected()); - - assertMessage(*session, "progress:", "find"); - assertMessage(*session, "progress:", "connect"); - assertMessage(*session, "progress:", "ready"); - - LOK_ASSERT_EQUAL(true, pollDisconnected(std::chrono::microseconds(1000000), *session)); - - TST_LOG("Ending Test: " << testname); - return TestResult::Ok; -} - -void UnitTimeoutWSPing::invokeWSDTest() -{ - UnitBase::TestResult result; - - result = testWSPing(); - if (result != TestResult::Ok) - exitTest(result); - - exitTest(TestResult::Ok); -} - -UnitBase* unit_create_wsd(void) { return new UnitTimeoutWSPing(); } - -/* vim:set shiftwidth=4 softtabstop=4 expandtab: */