Skip to content

Commit

Permalink
WIP: use transport
Browse files Browse the repository at this point in the history
Currently broken
  • Loading branch information
Sjors committed Jan 5, 2024
1 parent 7e89e1b commit 229fe0d
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 353 deletions.
101 changes: 0 additions & 101 deletions src/common/sv2_noise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,104 +406,3 @@ void Sv2Cipher::EncryptMessage(Span<std::byte> input, Span<std::byte> output)
m_cs2.EncryptMessage(input, output);
}
}

bool Sv2NoiseSession::ProcessMaybeHandshake(Span<std::byte> msg, bool send)
{
switch (m_session_state)
{
case SessionState::HANDSHAKE_STEP_1:
{
if (send) {
m_handshake_state.WriteMsgEphemeralPK(msg);
} else {
m_handshake_state.ReadMsgEphemeralPK(msg);
}

LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Noise session state -> HANDSHAKE_STEP_2\n");
m_session_state = SessionState::HANDSHAKE_STEP_2;
break;
}
case SessionState::HANDSHAKE_STEP_2:
{
if (send) {
m_handshake_state.WriteMsgES(msg);
} else {
bool res = m_handshake_state.ReadMsgES(msg);
if (!res) return false;
}

auto cipher_state = m_handshake_state.m_symmetric_state.Split();
auto cs1 = cipher_state[0];
auto cs2 = cipher_state[1];

m_hash = std::move(m_handshake_state.m_symmetric_state.m_hash_output);
m_cs1 = std::move(cs1);
m_cs2 = std::move(cs2);

LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Noise session state -> TRANSPORT\n");
m_session_state = SessionState::TRANSPORT;
break;
}
case SessionState::TRANSPORT:
{
Assume(false);
break;
}
}
return true;
}

Sv2NoiseSession::Sv2NoiseSession(bool initiator, CKey&& static_key): m_initiator{initiator}
{
m_handshake_state = Sv2HandshakeState(std::move(static_key));
LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Noise session state -> HANDSHAKE_STEP_1\n");
m_session_state = SessionState::HANDSHAKE_STEP_1;
}

void Sv2NoiseSession::EncryptMessage(Span<std::byte> input, Span<std::byte> output)
{
Assume(m_session_state == SessionState::TRANSPORT);
Assume(output.size() == Sv2NoiseSession::EncryptedMessageSize(input.size()));

if (m_initiator) {
m_cs1.EncryptMessage(input, output);
} else {
m_cs2.EncryptMessage(input, output);
}
}

bool Sv2NoiseSession::DecryptMessage(Span<std::byte> message)
{
Assume(m_session_state == SessionState::TRANSPORT);

if (m_initiator) {
return m_cs2.DecryptMessage(message);
} else {
return m_cs1.DecryptMessage(message);
}
}

const uint256& Sv2NoiseSession::GetSymmetricStateHash() const
{
return m_hash;
}

const SessionState& Sv2NoiseSession::GetSessionState() const
{
return m_session_state;
}

void Sv2HandshakeState::GenerateEphemeralKey(CKey& key) noexcept
{
Assume(!key.size());
key.MakeNewKey(true);
Assume(XOnlyPubKey(key.GetPubKey()).IsFullyValid());
};

size_t Sv2NoiseSession::EncryptedMessageSize(size_t msg_len) {
size_t num_chunks = msg_len / (NOISE_MAX_CHUNK_SIZE - POLY1305_TAGLEN);
if (msg_len % (NOISE_MAX_CHUNK_SIZE - POLY1305_TAGLEN) != 0) {
num_chunks++;
}
return msg_len + (num_chunks * POLY1305_TAGLEN);
}
44 changes: 0 additions & 44 deletions src/common/sv2_noise.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,52 +250,8 @@ class Sv2Cipher
class Sv2NoiseSession
{
public:
Sv2HandshakeState m_handshake_state;

Sv2NoiseSession(bool initiator, CKey&& static_key);

/**
* Process a noise msg to keep a handshake progressing
* May not be called in TRANSPORT state
* @throws std::runtime_error if the msg cannot be processed
* TODO: just return false
*/
[[ nodiscard ]] bool ProcessMaybeHandshake(Span<std::byte> msg, bool send);

/** Encrypt a message. Only call in TRANSPORT session state.
*
* @param[in] input message to be encrypted
* @param[out] output use EncryptedMessageSize() to get the correct size,
* must point to a different underlying buffer.
*/
void EncryptMessage(Span<std::byte> input, Span<std::byte> output);

/** Decrypt a message. Only call in TRANSPORT session state.
* The shorter decrypted chunks are concatenated and written
* back to msg.
*
* @param[in] message message to be decrypted
*
* @returns whether decryption succeeded
*/
[[ nodiscard ]] bool DecryptMessage(Span<std::byte> message);
const uint256& GetSymmetricStateHash() const;
const SessionState& GetSessionState() const;
bool HandshakeComplete() const
{
return m_session_state == SessionState::TRANSPORT;
}
/* Expected size after chunking and with MAC */
static size_t EncryptedMessageSize(size_t msg_len);

private:
bool m_initiator;

SessionState m_session_state;

uint256 m_hash;
Sv2CipherState m_cs1;
Sv2CipherState m_cs2;
};

#endif // BITCOIN_COMMON_SV2_NOISE_H
2 changes: 1 addition & 1 deletion src/common/sv2_transport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ CKey GenerateRandomKey() noexcept
return key;
}

Sv2Transport::Sv2Transport(bool initiating, CKey&& static_key) noexcept
Sv2Transport::Sv2Transport(bool initiating, CKey static_key) noexcept
: m_cipher{Sv2Cipher(std::move(static_key), initiating)}, m_initiating{initiating},
m_recv_state{initiating ? RecvState::HANDSHAKE_STEP_2 : RecvState::HANDSHAKE_STEP_1},
m_send_state{initiating ? SendState::HANDSHAKE_STEP_1 : SendState::HANDSHAKE_STEP_2},
Expand Down
2 changes: 1 addition & 1 deletion src/common/sv2_transport.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class Sv2Transport final : public Transport
* @param[in] static_key a securely generated key
*/
Sv2Transport(bool initiating, CKey&& static_key) noexcept;
Sv2Transport(bool initiating, CKey static_key) noexcept;

// Receive side functions.
bool ReceivedMessageComplete() const noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex);
Expand Down
148 changes: 42 additions & 106 deletions src/node/sv2_template_provider.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
#include <node/sv2_template_provider.h>

#include <common/args.h>
#include <common/sv2_transport.h>
#include <consensus/merkle.h>
#include <txmempool.h>
#include <util/thread.h>
#include <validation.h>

Sv2TemplateProvider::Sv2TemplateProvider(ChainstateManager& chainman, CTxMemPool& mempool) : m_chainman{chainman}, m_mempool{mempool}
{
// TODO: persist static key
m_static_key.MakeNewKey(true);

// TODO: get rid of Init() ???
Init({});
}

bool Sv2TemplateProvider::Start(const Sv2TemplateProviderOptions& options)
{
Init(options);
Expand Down Expand Up @@ -217,7 +225,8 @@ void Sv2TemplateProvider::ThreadSv2Handler()

auto sock = m_listening_socket->Accept(reinterpret_cast<struct sockaddr*>(&sockaddr), &sockaddr_len);
if (sock) {
m_sv2_clients.emplace_back(std::make_unique<Sv2Client>(Sv2Client{std::move(sock)}));
m_sv2_clients.push_back(std::make_unique<Sv2Client>(Sv2Client{std::move(sock)}));
m_sv2_clients[-1]->m_transport = std::make_unique<Sv2Transport>(/*initiating=*/false, m_static_key);
}
}

Expand Down Expand Up @@ -250,17 +259,37 @@ void Sv2TemplateProvider::ThreadSv2Handler()
try
{
auto msg_ = Span(bytes_received_buf, num_bytes_received);
Span<std::byte> msg(reinterpret_cast<std::byte*>(msg_.data()), msg_.size());
Span<const uint8_t> msg(reinterpret_cast<const uint8_t*>(msg_.data()), msg_.size());
while (msg.size() > 0) {
// absorb network data
if (!client->m_transport->ReceivedBytes(msg)) {
// Serious transport problem
client->m_disconnect_flag = true;
continue;
}

if (!client->m_noise->HandshakeComplete()) {
ProcessMaybeSv2Handshake(*client.get(), msg);
} else {
auto sv2_msgs = ReadAndDecryptSv2NetMsgs(*client.get(), msg);
if (client->m_transport->ReceivedMessageComplete()) {
Sv2NetMsg msg = client->m_transport->GetReceivedMessage();

for (auto& m : sv2_msgs)
{
ProcessSv2Message(m, *client.get());
// TODO: push to a queue first
ProcessSv2Message(std::move(msg), *client.get());
// complete = true;
}


// auto msg_ = Span(bytes_received_buf, num_bytes_received);
// Span<std::byte> msg(reinterpret_cast<std::byte*>(msg_.data()), msg_.size());

// if (!client->m_noise->HandshakeComplete()) {
// ProcessMaybeSv2Handshake(*client.get(), msg);
// } else {
// auto sv2_msgs = ReadAndDecryptSv2NetMsgs(*client.get(), msg);

// for (auto& m : sv2_msgs)
// {
// ProcessSv2Message(m, *client.get());
// }
// }
}
} catch (const std::exception& e) {
LogPrintLevel(BCLog::SV2, BCLog::Level::Error, "Received error when processing client message: %s\n", e.what());
Expand All @@ -271,32 +300,6 @@ void Sv2TemplateProvider::ThreadSv2Handler()
}
}

void Sv2TemplateProvider::ProcessMaybeSv2Handshake(Sv2Client& client, Span<std::byte> buffer)
{
const SessionState state_before = client.m_noise->GetSessionState();
Assume(state_before != SessionState::TRANSPORT);

bool res = client.m_noise->ProcessMaybeHandshake(buffer, /*send=*/false);
if (!res) throw std::runtime_error("Failed to parse Msg E from client\n");

// TODO: consider modifying ReadMsg to optionally return a reply, so
// we don't need to access client.m_noise->GetSessionState()
if (state_before == SessionState::HANDSHAKE_STEP_1) {
// Expect to have read the E msg.
// Expect state transition to have happened
Assume(client.m_noise->GetSessionState() == SessionState::HANDSHAKE_STEP_2);

LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Send noise handshake reply: ES\n");
std::byte msg_es[INITIATOR_EXPECTED_HANDSHAKE_MESSAGE_LENGTH];
Span<std::byte> msg_es_span(msg_es, INITIATOR_EXPECTED_HANDSHAKE_MESSAGE_LENGTH);
res = client.m_noise->ProcessMaybeHandshake(msg_es_span, /*send=*/true);
if(!res) throw std::runtime_error("Failed to construct Msg ES\n");
if (!SendBuf(client, msg_es_span)) {
throw std::runtime_error("Sv2TemplateProvider::ProcessSv2Message(): Failed to send Msg ES to client\n");
}
}
}

Sv2TemplateProvider::NewWorkSet Sv2TemplateProvider::BuildNewWorkSet(bool future_template, unsigned int coinbase_output_max_additional_size)
{
node::BlockAssembler::Options options;
Expand Down Expand Up @@ -589,78 +592,11 @@ std::vector<CTransactionRef> txs;

bool Sv2TemplateProvider::EncryptAndSendMessage(Sv2Client& client, node::Sv2NetMsg& net_msg)
{
const size_t encrypted_msg_size = Sv2NoiseSession::EncryptedMessageSize(net_msg.m_msg.size());
std::vector<std::byte> buffer(SV2_HEADER_ENCRYPTED_SIZE + encrypted_msg_size, std::byte(0));
Span<std::byte> buffer_span{MakeWritableByteSpan(buffer)};

// Header
DataStream ss_header_plain{};
ss_header_plain << net_msg.m_sv2_header;
LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Header: %s\n", HexStr(ss_header_plain));
Span<std::byte> header_encrypted{buffer_span.subspan(0, SV2_HEADER_ENCRYPTED_SIZE)};
client.m_noise->EncryptMessage(ss_header_plain, header_encrypted);

// Payload
Span<std::byte> payload_plain = MakeWritableByteSpan(net_msg.m_msg);
// TODO: truncate very long messages, about 100 bytes at the start and end
// is probably enough for most debugging.
// LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Payload: %s\n", HexStr(payload_plain));
Span<std::byte> payload_encrypted{buffer_span.subspan(SV2_HEADER_ENCRYPTED_SIZE, encrypted_msg_size)};
client.m_noise->EncryptMessage(payload_plain, payload_encrypted);

return SendBuf(client, buffer_span);
// TODO: pass message to transport and get the buffer of stuff to send back
// return SendBuf(client, buffer_span);
return false;
};

std::vector<node::Sv2NetMsg> Sv2TemplateProvider::ReadAndDecryptSv2NetMsgs(Sv2Client& client, Span<std::byte> buffer)
{
Assume(client.m_noise->GetSessionState() == SessionState::TRANSPORT);

size_t bytes_read = 0;
std::vector<node::Sv2NetMsg> sv2_msgs;

while (bytes_read < buffer.size())
{
// Decrypt the header.
Span<std::byte> encrypted_header = Span(&buffer[bytes_read], SV2_HEADER_ENCRYPTED_SIZE);
if (!client.m_noise->DecryptMessage(encrypted_header)) {
LogPrintLevel(BCLog::SV2, BCLog::Level::Debug, "Failed to decrypt header\n");
client.m_disconnect_flag = true;
break;
}
bytes_read += SV2_HEADER_ENCRYPTED_SIZE;

Span<std::byte> decrypted_header = encrypted_header.subspan(0, SV2_HEADER_PLAIN_SIZE);
LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Header: %s\n", HexStr(decrypted_header));

// Decode header
DataStream ss_header{decrypted_header};
node::Sv2NetHeader header;
ss_header >> header;

// Decrypt the payload
size_t expanded_size = Sv2NoiseSession::EncryptedMessageSize(header.m_msg_len);
Span<std::byte> encrypted_payload = Span(&buffer[bytes_read], expanded_size);
if (!client.m_noise->DecryptMessage(encrypted_payload)) {
LogPrintLevel(BCLog::SV2, BCLog::Level::Debug, "Failed to decrypt message payload\n");
client.m_disconnect_flag = true;
break;
}
Span<std::byte> payload = encrypted_payload.subspan(0, header.m_msg_len);
bytes_read += expanded_size;

LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Payload: %s\n", HexStr(payload));

// Add to sv2 message queue
std::vector<uint8_t> msg_payload(payload.size());
std::transform(payload.begin(), payload.end(), msg_payload.begin(),
[](std::byte b) { return static_cast<uint8_t>(b); });

sv2_msgs.emplace_back(std::move(header), std::move(msg_payload));
}

return sv2_msgs;
}

bool Sv2TemplateProvider::SendBuf(const Sv2Client& client, Span<std::byte> buffer) {
size_t total_sent = 0;
try {
Expand Down
Loading

0 comments on commit 229fe0d

Please sign in to comment.