Skip to content

Commit

Permalink
Merge pull request #3 from Kilemonn/udp-improvements
Browse files Browse the repository at this point in the history
Udp improvements - Refactor the world
  • Loading branch information
Kilemonn authored Jun 28, 2024
2 parents 01f5ccf + 6d7684f commit d56b564
Show file tree
Hide file tree
Showing 12 changed files with 359 additions and 184 deletions.
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ project(${PROJECT_NAME} VERSION 1.0)
set(HEADERS
src/serversocket/ServerSocket.h
src/socket/Socket.h
src/address/Address.h
src/address/SocketAddress.h
src/socketexceptions/BindingException.hpp
src/socketexceptions/SocketException.hpp
src/socketexceptions/TimeoutException.hpp
Expand All @@ -24,6 +24,7 @@ set(SOURCE
src/serversocket/ServerSocket.cpp
src/socket/Socket.cpp
src/socketexceptions/SocketError.cpp
src/address/SocketAddress.cpp
)

# Project Configuration - Adding as a lib
Expand Down
50 changes: 50 additions & 0 deletions src/address/SocketAddress.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#include "SocketAddress.h"

#include <optional>
#include <string>

namespace kt
{
kt::InternetProtocolVersion getInternetProtocolVersion(const kt::SocketAddress& address)
{
return static_cast<kt::InternetProtocolVersion>(address.address.sa_family);
}

long getPortNumber(const kt::SocketAddress& address)
{
kt::InternetProtocolVersion version = getInternetProtocolVersion(address);
if (version == kt::InternetProtocolVersion::IPV6)
{
return htonl(address.ipv6.sin6_port);
}
// I believe the address is in the same position for ipv4 and ipv6 structs, so it doesn't really matter.
// Doing the checks anway to make sure its fine
return htonl(address.ipv4.sin_port);
}

std::optional<std::string> resolveToAddress(const kt::SocketAddress& address)
{
const kt::InternetProtocolVersion protocolVersion = getInternetProtocolVersion(address);
const size_t addressLength = protocolVersion == kt::InternetProtocolVersion::IPV6 ? INET6_ADDRSTRLEN : INET_ADDRSTRLEN;
std::string asString;
asString.resize(addressLength);

if (protocolVersion == kt::InternetProtocolVersion::IPV6)
{
inet_ntop(static_cast<int>(protocolVersion), &address.ipv6.sin6_addr, &asString[0], addressLength);
}
else
{
inet_ntop(static_cast<int>(protocolVersion), &address.ipv4.sin_addr, &asString[0], addressLength);
}

// Removing trailing \0 bytes
const size_t delimiterIndex = asString.find_first_of('\0');
if (delimiterIndex != std::string::npos)
{
asString = asString.substr(0, delimiterIndex);
}
// Since we zero out the address, we need to check its not default initialised
return !asString.empty() && asString != "0.0.0.0" && asString != "::" ? std::optional<std::string>{asString} : std::nullopt;
}
}
11 changes: 11 additions & 0 deletions src/address/Address.h → src/address/SocketAddress.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@

#pragma once

#include "../enums/InternetProtocolVersion.h"

#include <string>
#include <optional>

#ifdef _WIN32

#ifndef WIN32_LEAN_AND_MEAN
Expand Down Expand Up @@ -33,4 +38,10 @@ namespace kt
sockaddr_in ipv4;
sockaddr_in6 ipv6;
} SocketAddress;

kt::InternetProtocolVersion getInternetProtocolVersion(const kt::SocketAddress&);

long getPortNumber(const kt::SocketAddress&);

std::optional<std::string> resolveToAddress(const kt::SocketAddress&);
}
74 changes: 37 additions & 37 deletions src/serversocket/ServerSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ namespace kt
* @throw SocketException - If the ServerSocket is unable to be instanciated or begin listening.
* @throw BindingException - If the ServerSocket is unable to bind to the specific port specified.
*/
ServerSocket::ServerSocket(const kt::SocketType type, const unsigned int& port, const unsigned int& connectionBacklogSize, const InternetProtocolVersion protocolVersion)
kt::ServerSocket::ServerSocket(const kt::SocketType type, const unsigned int& port, const unsigned int& connectionBacklogSize, const kt::InternetProtocolVersion protocolVersion)
{
this->socketDescriptor = getInvalidSocketValue();
this->port = port;
Expand All @@ -78,7 +78,7 @@ namespace kt
*
* @param socket - The ServerSocket object to be copied.
*/
ServerSocket::ServerSocket(const ServerSocket& socket)
kt::ServerSocket::ServerSocket(const kt::ServerSocket& socket)
{
this->port = socket.port;
this->type = socket.type;
Expand All @@ -94,7 +94,7 @@ namespace kt
*
* @return the copied socket
*/
ServerSocket& ServerSocket::operator=(const ServerSocket& socket)
kt::ServerSocket& kt::ServerSocket::operator=(const kt::ServerSocket& socket)
{
this->port = socket.port;
this->type = socket.type;
Expand All @@ -105,7 +105,7 @@ namespace kt
return *this;
}

void ServerSocket::constructSocket(const unsigned int& connectionBacklogSize)
void kt::ServerSocket::constructSocket(const unsigned int& connectionBacklogSize)
{
if (this->type == kt::SocketType::Wifi)
{
Expand All @@ -118,10 +118,10 @@ namespace kt
}
}

void ServerSocket::constructBluetoothSocket(const unsigned int& connectionBacklogSize)
void kt::ServerSocket::constructBluetoothSocket(const unsigned int& connectionBacklogSize)
{
#ifdef _WIN32
throw SocketException("ServerSocket::constructBluetoothSocket(unsigned int) is not supported on Windows.");
throw kt::SocketException("ServerSocket::constructBluetoothSocket(unsigned int) is not supported on Windows.");

/*SOCKADDR_BTH bluetoothAddress;
Expand Down Expand Up @@ -176,7 +176,7 @@ namespace kt

}

void ServerSocket::constructWifiSocket(const unsigned int& connectionBacklogSize)
void kt::ServerSocket::constructWifiSocket(const unsigned int& connectionBacklogSize)
{
const int socketType = SOCK_STREAM;
const int socketProtocol = IPPROTO_TCP;
Expand All @@ -185,7 +185,7 @@ namespace kt
WSADATA wsaData{};
if (int res = WSAStartup(MAKEWORD(2, 2), &wsaData); res != 0)
{
throw SocketException("WSAStartup Failed: " + std::to_string(res));
throw kt::SocketException("WSAStartup Failed: " + std::to_string(res));
}
#endif

Expand All @@ -194,7 +194,7 @@ namespace kt
this->socketDescriptor = socket(static_cast<int>(this->protocolVersion), socketType, socketProtocol);
if (isInvalidSocket(this->socketDescriptor))
{
throw SocketException("Error establishing wifi server socket: " + getErrorCode());
throw kt::SocketException("Error establishing wifi server socket: " + getErrorCode());
}

#ifdef __linux__
Expand All @@ -206,20 +206,20 @@ namespace kt
#endif

#ifdef _WIN32
if (this->protocolVersion == InternetProtocolVersion::IPV6)
if (this->protocolVersion == kt::InternetProtocolVersion::IPV6)
{
const int disableOption = 0;
if (setsockopt(this->socketDescriptor, IPPROTO_IPV6, IPV6_V6ONLY, (const char*)&disableOption, sizeof(disableOption)) != 0)
{
throw SocketException("Failed to set IPV6_V6ONLY socket option: " + getErrorCode());
throw kt::SocketException("Failed to set IPV6_V6ONLY socket option: " + getErrorCode());
}
}
#endif

if (bind(this->socketDescriptor, &this->serverAddress.address, socketSize) == -1)
{
this->close();
throw BindingException("Error binding connection, the port " + std::to_string(this->port) + " is already being used: " + getErrorCode());
throw kt::BindingException("Error binding connection, the port " + std::to_string(this->port) + " is already being used: " + getErrorCode());
}

if (this->port == 0)
Expand All @@ -230,11 +230,11 @@ namespace kt
if (listen(this->socketDescriptor, connectionBacklogSize) == -1)
{
this->close();
throw SocketException("Error Listening on port " + std::to_string(this->port) + ": " + getErrorCode());
throw kt::SocketException("Error Listening on port " + std::to_string(this->port) + ": " + getErrorCode());
}
}

size_t ServerSocket::initialiseServerAddress()
size_t kt::ServerSocket::initialiseServerAddress()
{
addrinfo hint{};
memset(&this->serverAddress, 0, sizeof(this->serverAddress));
Expand All @@ -248,24 +248,24 @@ namespace kt
if (getaddrinfo(nullptr, std::to_string(this->port).c_str(), &hint, &addresses) != 0)
{
freeaddrinfo(addresses);
throw SocketException("Failed to retrieve address info of local hostname. " + getErrorCode());
throw kt::SocketException("Failed to retrieve address info of local hostname. " + getErrorCode());
}
this->protocolVersion = static_cast<InternetProtocolVersion>(addresses->ai_family);
this->protocolVersion = static_cast<kt::InternetProtocolVersion>(addresses->ai_family);
std::memcpy(&this->serverAddress, addresses->ai_addr, addresses->ai_addrlen);
freeaddrinfo(addresses);
return addresses->ai_addrlen;
}

void ServerSocket::initialisePortNumber()
void kt::ServerSocket::initialisePortNumber()
{
socklen_t socketSize = sizeof(this->serverAddress);
if (getsockname(this->socketDescriptor, &this->serverAddress.address, &socketSize) != 0)
{
this->close();
throw BindingException("Unable to retrieve randomly bound port number during socket creation. " + getErrorCode());
throw kt::BindingException("Unable to retrieve randomly bound port number during socket creation. " + getErrorCode());
}

if (this->protocolVersion == InternetProtocolVersion::IPV6)
if (this->protocolVersion == kt::InternetProtocolVersion::IPV6)
{
this->port = ntohs(this->serverAddress.ipv6.sin6_port);
}
Expand All @@ -276,9 +276,9 @@ namespace kt
}


void ServerSocket::setDiscoverable()
void kt::ServerSocket::setDiscoverable()
{
throw SocketException("ServerSocket::setDiscoverable() not implemented.");
throw kt::SocketException("ServerSocket::setDiscoverable() not implemented.");

#if __linux__
hci_dev_req req;
Expand All @@ -296,7 +296,7 @@ namespace kt
/**
* @return the *kt::SocketType* for this *kt::ServerSocket*.
*/
kt::SocketType ServerSocket::getType() const
kt::SocketType kt::ServerSocket::getType() const
{
return this->type;
}
Expand All @@ -305,15 +305,15 @@ namespace kt
* Used to get the port number that the ServerSocket is listening on.
* @return An unsigned int of the port number that the ServerSocket is listening on.
*/
unsigned int ServerSocket::getPort() const
unsigned int kt::ServerSocket::getPort() const
{
return this->port;
}

/**
* @return the *kt::InternetProtocolVersion* for this *kt::ServerSocket*.
*/
InternetProtocolVersion ServerSocket::getInternetProtocolVersion() const
kt::InternetProtocolVersion kt::ServerSocket::getInternetProtocolVersion() const
{
return this->protocolVersion;
}
Expand All @@ -326,23 +326,23 @@ namespace kt
*
* @returns kt::Socket object of the receiver who has just connected to the kt::ServerSocket.
*/
Socket ServerSocket::accept(const long& timeout)
kt::Socket kt::ServerSocket::accept(const long& timeout)
{
if (this->type == SocketType::Wifi)
if (this->type == kt::SocketType::Wifi)
{
return this->acceptWifiConnection(timeout);
}
else if (this->type == SocketType::Bluetooth)
else if (this->type == kt::SocketType::Bluetooth)
{
return this->acceptBluetoothConnection(timeout);
}
else
{
throw SocketException("Cannot accept connection with SocketType set as SocketType::None");
throw kt::SocketException("Cannot accept connection with SocketType set as SocketType::None");
}
}

Socket ServerSocket::acceptWifiConnection(const long& timeout)
kt::Socket kt::ServerSocket::acceptWifiConnection(const long& timeout)
{
if (timeout > 0)
{
Expand All @@ -357,25 +357,25 @@ namespace kt
}
}

SocketAddress acceptedAddress{};
kt::SocketAddress acceptedAddress{};
socklen_t sockLen = sizeof(acceptedAddress);
SOCKET temp = ::accept(this->socketDescriptor, &acceptedAddress.address, &sockLen);
if (isInvalidSocket(temp))
{
throw SocketException("Failed to accept connection. Socket is in an invalid state.");
throw kt::SocketException("Failed to accept connection. Socket is in an invalid state.");
}

unsigned int portNum = this->getInternetProtocolVersion() == InternetProtocolVersion::IPV6 ? htons(acceptedAddress.ipv6.sin6_port) : htons(acceptedAddress.ipv4.sin_port);
std::optional<std::string> hostname = kt::resolveToAddress(&acceptedAddress, this->getInternetProtocolVersion());
unsigned int portNum = this->getInternetProtocolVersion() == kt::InternetProtocolVersion::IPV6 ? htons(acceptedAddress.ipv6.sin6_port) : htons(acceptedAddress.ipv4.sin_port);
std::optional<std::string> hostname = kt::resolveToAddress(acceptedAddress);
if (!hostname.has_value())
{
throw SocketException("Unable to resolve accepted hostname from accepted socket.");
throw kt::SocketException("Unable to resolve accepted hostname from accepted socket.");
}

return Socket(temp, this->type, kt::SocketProtocol::TCP, hostname.value(), portNum, this->getInternetProtocolVersion());
return kt::Socket(temp, this->type, kt::SocketProtocol::TCP, hostname.value(), portNum, this->getInternetProtocolVersion());
}

Socket ServerSocket::acceptBluetoothConnection(const long& timeout)
kt::Socket kt::ServerSocket::acceptBluetoothConnection(const long& timeout)
{
if (timeout > 0)
{
Expand All @@ -391,7 +391,7 @@ namespace kt
}

#ifdef __linux__
throw SocketException("Not yet implemented.");
throw kt::SocketException("Not yet implemented.");
// Remove bluetooth related code

// sockaddr_rc remoteDevice = { 0 };
Expand Down
24 changes: 12 additions & 12 deletions src/serversocket/ServerSocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

#include <optional>

#include "../address/Address.h"
#include "../address/SocketAddress.h"

#include "../socket/Socket.h"

Expand Down Expand Up @@ -38,10 +38,10 @@ namespace kt
{
class ServerSocket
{
private:
protected:
unsigned int port;
SocketType type = SocketType::None;
InternetProtocolVersion protocolVersion = InternetProtocolVersion::Any;
kt::SocketType type = kt::SocketType::None;
kt::InternetProtocolVersion protocolVersion = kt::InternetProtocolVersion::Any;
kt::SocketAddress serverAddress = {};
SOCKET socketDescriptor = getInvalidSocketValue();

Expand All @@ -52,20 +52,20 @@ namespace kt
void initialisePortNumber();
size_t initialiseServerAddress();

Socket acceptWifiConnection(const long& = 0);
Socket acceptBluetoothConnection(const long& = 0);
kt::Socket acceptWifiConnection(const long& = 0);
kt::Socket acceptBluetoothConnection(const long& = 0);

public:
ServerSocket() = default;
ServerSocket(const SocketType, const unsigned int& = 0, const unsigned int& = 20, const InternetProtocolVersion = InternetProtocolVersion::Any);
ServerSocket(const ServerSocket&);
ServerSocket& operator=(const ServerSocket&);
ServerSocket(const kt::SocketType, const unsigned int& = 0, const unsigned int& = 20, const kt::InternetProtocolVersion = kt::InternetProtocolVersion::Any);
ServerSocket(const kt::ServerSocket&);
kt::ServerSocket& operator=(const kt::ServerSocket&);

SocketType getType() const;
InternetProtocolVersion getInternetProtocolVersion() const;
kt::SocketType getType() const;
kt::InternetProtocolVersion getInternetProtocolVersion() const;
unsigned int getPort() const;

Socket accept(const long& = 0);
kt::Socket accept(const long& = 0);
void close();
};

Expand Down
Loading

0 comments on commit d56b564

Please sign in to comment.