// (file listing metadata: Files — 650 lines, 22 KiB, C++)

/*
* This file is part of the Flowee project
* Copyright (C) 2016, 2019-2024 Tom Zander <tom@flowee.org>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "testNWM.h"
#include "streaming/BufferPools.h"
#include <boost/lexical_cast.hpp>
#include <networkmanager/NetworkManager.h>
#include <networkmanager/NetworkManager_p.h>
#include <utils/hash.h>
#include <WorkerThreads.h>
#include <Message.h>
#include <array>
namespace {
/// Write @p value into the four bytes at @p dest in little-endian order,
/// independent of the host's byte order.
void writeLE32ForTest(char *dest, uint32_t value)
{
    dest[0] = static_cast<char>(value & 0xFF);
    dest[1] = static_cast<char>((value >> 8) & 0xFF);
    dest[2] = static_cast<char>((value >> 16) & 0xFF);
    dest[3] = static_cast<char>((value >> 24) & 0xFF);
}

/// Interpret the four bytes at @p src as a little-endian uint32,
/// independent of the host's byte order.
uint32_t readLE32ForTest(const unsigned char *src)
{
    return static_cast<uint32_t>(src[0])
        | (static_cast<uint32_t>(src[1]) << 8)
        | (static_cast<uint32_t>(src[2]) << 16)
        | (static_cast<uint32_t>(src[3]) << 24);
}

/**
 * Build a legacy P2P packet: a 24-byte header followed by @p body.
 *
 * Header layout as written here: 4 magic bytes, the command zero-padded to
 * 12 bytes, the body size as LE32, then the first four bytes of Hash(body)
 * as checksum.
 *
 * @param magic          network magic placed in the first four bytes.
 * @param command        command name; copied with strncpy into 12 bytes.
 * @param body           payload appended after the header.
 * @param validChecksum  when false the checksum is deliberately corrupted.
 * @return the complete packet bytes.
 */
QByteArray legacyPacket(const std::array<uint8_t, 4> &magic, const char *command,
        const QByteArray &body, bool validChecksum)
{
    QByteArray packet(24 + body.size(), 0);
    memcpy(packet.data(), magic.data(), magic.size());
    // packet is zero-filled, so shorter commands end up NUL padded.
    strncpy(packet.data() + 4, command, 12);
    writeLE32ForTest(packet.data() + 16, static_cast<uint32_t>(body.size()));
    const uint256 hash = Hash(body.constData(), body.constData() + body.size());
    // The checksum is the first four hash bytes exactly as they sit in memory.
    // Reading them as little-endian and writing them back little-endian reproduces
    // those bytes on every host; a memcpy into a uint32_t would swap them on
    // big-endian machines once written out via writeLE32ForTest.
    unsigned char hashPrefix[4];
    memcpy(hashPrefix, &hash, sizeof(hashPrefix));
    uint32_t checksum = readLE32ForTest(hashPrefix);
    if (!validChecksum)
        checksum ^= 0x01020304;
    writeLE32ForTest(packet.data() + 20, checksum);
    memcpy(packet.data() + 24, body.constData(), static_cast<size_t>(body.size()));
    return packet;
}
}
/// Seeds rand() once per run; the tests use it to pick their listening port.
TestNWM::TestNWM()
{
    const auto seed = time(nullptr);
    srand(static_cast<unsigned int>(seed));
}
void TestNWM::testBigMessage()
{
auto localhost = boost::asio::ip::address_v4::loopback();
const int port = std::max(1100, rand() % 32000);
std::list<NetworkConnection> stash;
int messageSize = -1;
WorkerThreads threads(5);
NetworkManager server(threads.ioContext());
server.bind(boost::asio::ip::tcp::endpoint(localhost, port), [&stash, &messageSize](NetworkConnection &connection) {
connection.setOnIncomingMessage([&messageSize](const Message &message) {
messageSize = message.body().size();
});
connection.accept();
stash.push_back(std::move(connection));
});
NetworkManager client(threads.ioContext());
EndPoint ep;
ep.announcePort = port;
ep.ipAddress = localhost;
auto con = client.connection(ep);
con.connect();
const int BigSize = 500000;
Streaming::BufferPool pool(BigSize);
for (int i =0; i < BigSize; ++i) {
pool.data()[i] = 0xFF & i;
}
Message message(pool.commit(BigSize), 1);
con.send(message);
/*
* This big message should be split into lots of messages but only
* one should arrive at the other end.
*/
QTRY_COMPARE(messageSize, BigSize);
}
/*
 * Exercises RingBuffer bookkeeping: append()/removeTip(), the separate read
 * cursor (markRead / markAllUnread / unreadTip / hasItemsMarkedRead) and,
 * on a tiny buffer, the capacity queries after a slot is freed and reused.
 * The assertions are order dependent: each block relies on the state left by
 * the previous one.
 */
void TestNWM::testRingBuffer()
{
    RingBuffer<int> buf(2000);
    // The expectations below assume exactly this capacity; fail early if the implementation changes it.
    QCOMPARE(buf.reserved(), 2000);
    QCOMPARE(buf.isEmpty(), true);
    QCOMPARE(buf.count(), 0);
    QCOMPARE(buf.hasItemsMarkedRead(), false);
    QCOMPARE(buf.hasUnread(), false);
    // Fill with values 0..249; nothing is marked read yet.
    for (int i = 0; i < 250; ++i) {
        buf.append(i);
    }
    QCOMPARE(buf.hasItemsMarkedRead(), false);
    QCOMPARE(buf.hasUnread(), true);
    QCOMPARE(buf.isEmpty(), false);
    QCOMPARE(buf.count(), 250);
    QCOMPARE(buf.tip(), 0);
    QCOMPARE(buf.unreadTip(), 0);
    // Marking items read moves the read cursor but does not remove anything.
    buf.markRead(10);
    QCOMPARE(buf.hasItemsMarkedRead(), true);
    QCOMPARE(buf.isEmpty(), false);
    QCOMPARE(buf.count(), 250);
    QCOMPARE(buf.tip(), 0);
    QCOMPARE(buf.hasUnread(), true);
    QCOMPARE(buf.unreadTip(), 10);
    // markAllUnread() resets the read cursor back to the tip.
    buf.markAllUnread();
    QCOMPARE(buf.hasItemsMarkedRead(), false);
    QCOMPARE(buf.isEmpty(), false);
    QCOMPARE(buf.count(), 250);
    QCOMPARE(buf.tip(), 0);
    QCOMPARE(buf.hasUnread(), true);
    QCOMPARE(buf.unreadTip(), 0);
    // Read all but the last item...
    buf.markRead(249);
    QCOMPARE(buf.hasItemsMarkedRead(), true);
    QCOMPARE(buf.isEmpty(), false);
    QCOMPARE(buf.count(), 250);
    QCOMPARE(buf.tip(), 0);
    QCOMPARE(buf.hasUnread(), true);
    QCOMPARE(buf.unreadTip(), 249);
    // ...and then the last one, so nothing is left unread.
    buf.markRead(1);
    QCOMPARE(buf.hasItemsMarkedRead(), true);
    QCOMPARE(buf.isEmpty(), false);
    QCOMPARE(buf.count(), 250);
    QCOMPARE(buf.tip(), 0);
    QCOMPARE(buf.hasUnread(), false);
    // don't call unreadTip when hasUnread returns false. It will assert.
    // remove 200 of the 250 items (values 0..199), leaving values 200..249.
    for (int i = 0; i < 200; ++i) {
        QCOMPARE(buf.hasItemsMarkedRead(), true);
        QCOMPARE(buf.isEmpty(), false);
        QCOMPARE(buf.count(), 250 - i);
        QCOMPARE(buf.tip(), i);
        QCOMPARE(buf.hasUnread(), false);
        buf.removeTip();
    }
    // add 900 items (values 1000..1899) so we end up with 950 items.
    // Before the first append everything is read; afterwards the first new
    // item (value 1000) is the unread tip.
    for (int i = 0; i < 900; ++i) {
        QCOMPARE(buf.hasItemsMarkedRead(), true);
        QCOMPARE(buf.isEmpty(), false);
        QCOMPARE(buf.count(), 50 + i);
        QCOMPARE(buf.tip(), 200);
        QCOMPARE(buf.hasUnread(), i != 0);
        if (i > 0)
            QCOMPARE(buf.unreadTip(), 1000);
        buf.append(1000 + i);
    }
    // Skip 800 unread items (values 1000..1799): the unread tip becomes value
    // 1800, which is item 850 counted from the tip (50 old + 800 new).
    buf.markRead(800);
    QCOMPARE(buf.hasItemsMarkedRead(), true);
    QCOMPARE(buf.isEmpty(), false);
    QCOMPARE(buf.count(), 950);
    QCOMPARE(buf.tip(), 200);
    QCOMPARE(buf.hasUnread(), true);
    QCOMPARE(buf.unreadTip(), 1800);
    // Remove the 50 remaining items of the first batch (values 200..249).
    // This leaves 900 items with values 1000..1899 and the read position at value 1800.
    for (int i = 0; i < 50; ++i) {
        QCOMPARE(buf.hasItemsMarkedRead(), true);
        QCOMPARE(buf.isEmpty(), false);
        QCOMPARE(buf.count(), 950 - i);
        QCOMPARE(buf.tip(), 200 + i);
        QCOMPARE(buf.hasUnread(), true);
        QCOMPARE(buf.unreadTip(), 1800);
        buf.removeTip();
    }
    // Remove all remaining items (values 1000..1899). The first 800 of them
    // were marked read; once the tip passes value 1800 it is also the unread tip.
    for (int i = 0; i < 900; ++i) {
        QCOMPARE(buf.hasItemsMarkedRead(), i < 800);
        QCOMPARE(buf.isEmpty(), false);
        QCOMPARE(buf.count(), 900 - i);
        QCOMPARE(buf.tip(), 1000 + i);
        QCOMPARE(buf.hasUnread(), true);
        QCOMPARE(buf.unreadTip(), std::max(1800, 1000 + i));
        buf.removeTip();
    }
    // its empty now
    QCOMPARE(buf.hasItemsMarkedRead(), false);
    QCOMPARE(buf.isEmpty(), true);
    QCOMPARE(buf.count(), 0);
    QCOMPARE(buf.hasUnread(), false);
    // A 3-slot buffer: fill it, free one slot with removeTip() and append again
    // so the freed slot is reused while tip() keeps returning the oldest item.
    RingBuffer<int> wrap(3);
    QCOMPARE(wrap.reserved(), 3);
    QCOMPARE(wrap.slotsAvailable(), 3);
    QCOMPARE(wrap.isFull(), false);
    wrap.append(1);
    wrap.append(2);
    wrap.append(3);
    QCOMPARE(wrap.count(), 3);
    QCOMPARE(wrap.slotsAvailable(), 0);
    QCOMPARE(wrap.isFull(), true);
    wrap.markRead(1);
    wrap.removeTip();
    QCOMPARE(wrap.count(), 2);
    QCOMPARE(wrap.slotsAvailable(), 1);
    QCOMPARE(wrap.isFull(), false);
    QCOMPARE(wrap.tip(), 2);
    wrap.append(4);
    QCOMPARE(wrap.count(), 3);
    QCOMPARE(wrap.slotsAvailable(), 0);
    QCOMPARE(wrap.isFull(), true);
    QCOMPARE(wrap.tip(), 2);
}
void TestNWM::testHeaderInt()
{
auto localhost = boost::asio::ip::address_v4::loopback();
const int port = std::max(1100, rand() % 32000);
QMutex writeLock;
std::map<int, int> headerMap;
WorkerThreads threads;
NetworkManager server(threads.ioContext());
std::list<NetworkConnection> stash;
server.bind(boost::asio::ip::tcp::endpoint(localhost, port), [&stash, &headerMap, &writeLock](NetworkConnection &connection) {
connection.setOnIncomingMessage([&headerMap, &writeLock](const Message &message) {
QMutexLocker l(&writeLock);
headerMap = message.headerData();
});
connection.accept();
stash.push_back(std::move(connection));
});
NetworkManager client(threads.ioContext());
auto con = client.connection(EndPoint(localhost, port));
const int MessageSize = 20000;
Streaming::BufferPool pool(MessageSize);
for (int i =0; i < MessageSize; ++i) {
pool.data()[i] = 0xFF & i;
}
Message message(pool.commit(MessageSize), 1);
message.setHeaderInt(11, 312);
message.setHeaderInt(233, 12521);
message.setHeaderInt(1111, 1112);
con.send(message);
QCOMPARE((int) message.headerData().size(), 5); // 3 from above and the service/message ids
QTRY_COMPARE(message.headerData(), headerMap);
}
void TestNWM::testChunkReadQueue()
{
/*
* The NWM does flow control using the outgoing-message-queue size
* This means we might end up pausing processing of incoming traffic in order to
* wait for the outgoing data to be sent.
*
* Lets test that we still manage to send everything.
*
* The way to test this is simply that when we get 10 incoming messages, which generate
* 1000 outgoing messages, then we expect the NWM to stop processing the incoming and
* push a 'send' in between.
*/
auto localhost = boost::asio::ip::address_v4::loopback();
const int port = std::max(1100, rand() % 32000);
std::list<NetworkConnection> connections;
WorkerThreads threads;
NetworkManager receiver(threads.ioContext());
receiver.bind(boost::asio::ip::tcp::endpoint(localhost, port), [&connections](NetworkConnection &connection) {
connection.setMessageQueueSizes(1000, 1000);
connections.push_back(std::move(connection));
NetworkConnection *con = &connections.back();
con->setOnIncomingMessage([con](const Message &message) {
// first send a high prio, those are useful to measure the chunk-size.
con->send(Message(1, 1), NetworkConnection::HighPriority);
// for each incoming connection we send 100.
for (int i = 0; i < 100; ++i) {
con->send(Message(message.serviceId(), message.messageId() + 1));
}
});
con->accept();
});
NetworkManager sender(threads.ioContext());
EndPoint ep;
ep.announcePort = port;
ep.ipAddress = localhost;
auto con = sender.connection(ep);
con.setMessageQueueSizes(1000, 1000);
struct ReplyParser {
int plainMessageCount = 0;
int prioMessageCount = 0;
bool ok = false;
void replyReceived(const Message &message) {
if (message.serviceId() == 1)
++prioMessageCount;
else
++plainMessageCount;
if (!ok && plainMessageCount > 300 && prioMessageCount < 5) {
// Prio messages are sent every flush of the other side, so this is how
// we know that the replies have been chunked. We get the first 4 batches
// (4 + 400 messages) in one go and then we get another such batch and
// a last batch to finish the (10 + 1000) messages count.
// If there was no chunking we'd have gotten all prio messages
// in one (or nothing at all).
ok = true;
}
}
};
ReplyParser parser;
con.setOnIncomingMessage(std::bind(&ReplyParser::replyReceived, &parser, std::placeholders::_1));
con.connect();
// we send 10 messages from sender to receiver.
for (int i = 0; i < 10; ++i) {
con.send(Message(10, 5));
}
QTRY_COMPARE(parser.ok, true);
}
void TestNWM::testAsyncSendQueueFullDoesNotTerminate()
{
const auto localhost = boost::asio::ip::address_v4::loopback();
const int port = std::max(1100, rand() % 32000);
WorkerThreads threads(2);
NetworkManager client(threads.ioContext());
auto con = client.connection(EndPoint(localhost, port));
con.setMessageQueueSizes(11, 3);
QTest::qWait(50); // Let the queue-size update reach the connection strand.
Message message(1, 1);
for (int i = 0; i < 100; ++i)
con.send(message);
QTest::qWait(250); // Queued sends must be dropped, not escape the IO thread.
QVERIFY(true);
}
void TestNWM::testNativeEnvelopeMinimumLength()
{
const auto localhost = boost::asio::ip::address_v4::loopback();
const int port = std::max(1100, rand() % 32000);
std::list<NetworkConnection> stash;
QAtomicInt delivered(0);
QAtomicInt disconnected(0);
WorkerThreads threads(5);
NetworkManager server(threads.ioContext());
server.bind(boost::asio::ip::tcp::endpoint(localhost, port), [&stash, &delivered, &disconnected](NetworkConnection &connection) {
connection.setOnIncomingMessage([&delivered](const Message &) {
delivered.ref();
});
connection.setOnDisconnected([&disconnected](const EndPoint &) {
disconnected.ref();
});
connection.accept();
stash.push_back(std::move(connection));
});
boost::asio::io_context clientContext;
boost::asio::ip::tcp::socket socket(clientContext);
socket.connect(boost::asio::ip::tcp::endpoint(localhost, port));
QByteArray packet(4, 0);
packet[0] = 1; // Native packet length includes these two length bytes.
packet[2] = 8; // Pass the first-packet native probe to exercise the length check.
boost::asio::write(socket, boost::asio::buffer(packet.constData(), packet.size()));
socket.close();
QTRY_COMPARE(disconnected.loadAcquire(), 1);
QCOMPARE(delivered.loadAcquire(), 0);
}
void TestNWM::testNativeParserPacketBounds()
{
const auto localhost = boost::asio::ip::address_v4::loopback();
const int port = std::max(1100, rand() % 32000);
std::list<NetworkConnection> stash;
QAtomicInt delivered(0);
QAtomicInt disconnected(0);
WorkerThreads threads(5);
NetworkManager server(threads.ioContext());
server.bind(boost::asio::ip::tcp::endpoint(localhost, port), [&stash, &delivered, &disconnected](NetworkConnection &connection) {
connection.setOnIncomingMessage([&delivered](const Message &) {
delivered.ref();
});
connection.setOnDisconnected([&disconnected](const EndPoint &) {
disconnected.ref();
});
connection.accept();
stash.push_back(std::move(connection));
});
boost::asio::io_context clientContext;
boost::asio::ip::tcp::socket socket(clientContext);
socket.connect(boost::asio::ip::tcp::endpoint(localhost, port));
QByteArray packet;
packet.append(char(4)); // Packet length includes these two length bytes.
packet.append(char(0));
packet.append(char(8)); // ServiceId tag.
packet.append(char(1)); // ServiceId value, but no HeaderEnd in this packet.
packet.append(char(4)); // Would look like HeaderEnd if the parser reads past packetLength.
packet.append(char(0));
packet.append(char(4));
packet.append(char(0));
boost::asio::write(socket, boost::asio::buffer(packet.constData(), packet.size()));
socket.close();
QTRY_COMPARE(disconnected.loadAcquire(), 1);
QCOMPARE(delivered.loadAcquire(), 0);
}
void TestNWM::testLegacyP2PEnvelopeValidation()
{
const auto localhost = boost::asio::ip::address_v4::loopback();
const std::array<uint8_t, 4> expectedMagic{{0xe3, 0xe1, 0xf3, 0xe8}};
const std::array<uint8_t, 4> wrongMagic{{0xfa, 0xbf, 0xb5, 0xda}};
auto runCase = [&](const QByteArray &packet, int expectedDelivered) {
const int port = std::max(1100, rand() % 32000);
QAtomicInt delivered(0);
std::list<NetworkConnection> stash;
WorkerThreads threads(5);
NetworkManager server(threads.ioContext());
server.setMessageIdLookup({{Api::P2P::Ping, "ping"}});
server.setLegacyNetworkId(std::vector<uint8_t>(expectedMagic.begin(), expectedMagic.end()));
server.bind(boost::asio::ip::tcp::endpoint(localhost, port),
[&stash, &delivered](NetworkConnection &connection) {
connection.setMessageHeaderLegacy(true);
connection.setOnIncomingMessage([&delivered](const Message &) {
delivered.ref();
});
connection.accept();
stash.push_back(std::move(connection));
});
boost::asio::io_context clientContext;
boost::asio::ip::tcp::socket socket(clientContext);
socket.connect(boost::asio::ip::tcp::endpoint(localhost, port));
boost::asio::write(socket, boost::asio::buffer(packet.constData(), packet.size()));
socket.close();
if (expectedDelivered > 0)
QTRY_COMPARE(delivered.loadAcquire(), expectedDelivered);
else
QTest::qWait(200);
QCOMPARE(delivered.loadAcquire(), expectedDelivered);
};
const QByteArray body("12345678", 8);
runCase(legacyPacket(expectedMagic, "ping", body, true), 1);
runCase(legacyPacket(wrongMagic, "ping", body, true), 0);
runCase(legacyPacket(expectedMagic, "ping", body, false), 0);
}
/*
 * Binds an SSL server and checks that:
 *  - a plain client and a client without a certificate both fail,
 *  - a client with the certificate connects and its message arrives,
 *  - after disconnect/connect the same connection works again.
 */
void TestNWM::testEncrypted()
{
    WorkerThreads threads(5);
    NetworkManager nw(threads.ioContext());
    /*
     * Certificate creation is easy with openssl:
     *
     * # private key
     * openssl genrsa -out privkey.pem 2048
     * # Create a certificate signing request using your new key
     * openssl req -new -key privkey.pem -out certreq.csr
     * # Self-sign your CSR with your own private key:
     * openssl x509 -req -days 3650 -in certreq.csr -signkey privkey.pem -out newcert.pem
     *
     * We need a DH temp file.
     * openssl dhparam -outform PEM -out dh2048.pem 2048
     *
     * In production all 2048s should likely go to 4096
     */
    auto pool = Streaming::pool(1359);
    QFile newcert(":/newcert.pem");
    QVERIFY(newcert.open(QIODevice::ReadOnly));
    QCOMPARE(newcert.size(), 1359);
    QCOMPARE(newcert.read(pool->begin(), 1359), 1359);
    auto newCertBuf = pool->commit(1359);
    QFile privKey(":/privkey.pem");
    QVERIFY(privKey.open(QIODevice::ReadOnly));
    QCOMPARE(privKey.size(), 1704);
    // The pool was only requested with room for the certificate (1359 bytes);
    // request it again with the key's size before writing 1704 bytes at begin().
    pool = Streaming::pool(1704);
    QCOMPARE(privKey.read(pool->begin(), 1704), 1704);
    auto privKeyBuf = pool->commit(1704);
    auto localhost = boost::asio::ip::address_v4::loopback();
    const int port = std::max(1100, rand() % 32000);
    logFatal() << "Trying to bind on:" << port;
    QAtomicInt gotIncoming(0);
    Streaming::ConstBuffer empty;
    // bind on a random port
    NetworkConnection serverCon;
    nw.bindSsl(boost::asio::ip::tcp::endpoint(localhost, port), newCertBuf, privKeyBuf, empty, [&serverCon, &gotIncoming](NetworkConnection &connection) {
        logFatal() << "Server: Got incoming connection";
        connection.setOnIncomingMessage([&gotIncoming](const Message &m) {
            logFatal() << "Server: Got incoming message";
            if (m.serviceId() == 15 && m.messageId() == 1)
                gotIncoming.ref();
            else
                logFatal() << "Got unrecognized message";
        });
        connection.accept();
        serverCon = std::move(connection);
    });
    Message headerOnlyMessage(15, 1);
    {
        // first try to connect without handshake.
        // this will connect and send data and then get
        // disconnected by the server because we didn't ssl handshake.
        QAtomicInt counter(0);
        NetworkManager client(threads.ioContext());
        EndPoint ep;
        ep.announcePort = port;
        ep.ipAddress = localhost;
        QCOMPARE(ep.encrypted, false); // make sure this will fail
        auto con1 = client.connection(ep);
        con1.setOnConnected([](const EndPoint&) { logCritical() << "Con 1 connected"; });
        con1.setOnDisconnected([&counter](const EndPoint&) {
            logCritical() << "Con 1 disconnected";
            counter.ref();
        });
        con1.setOnError([](int, const boost::system::error_code&) {
            QFAIL("No error expected");
        });
        con1.send(headerOnlyMessage);
        // just saying its encrypted doesn't help, it will just fail differently.
        // So this one should never connect but fail in the handshake. Which is practically
        // the same as con1, except that we never 'connect'.
        ep.encrypted = true;
        auto con2 = client.connection(ep);
        QVERIFY(con1.connectionId() != con2.connectionId());
        con2.setOnConnected([](const EndPoint&) {
            QFAIL("should not connect");
        });
        con2.setOnDisconnected([](const EndPoint&) {
            QFAIL("Hmm, wait what?"); // if we don't connect then we can't disconnect
        });
        con2.setOnError([&counter](int connectionId, const boost::system::error_code &err) {
            logDebug() << connectionId << err.message();
            counter.ref();
        });
        con2.send(headerOnlyMessage);
        QTRY_COMPARE(counter.loadAcquire(), 2);
    }
    NetworkManager client(threads.ioContext());
    QAtomicInt gotConnect(0);
    gotIncoming.storeRelease(0);
    EndPoint ep;
    ep.announcePort = port;
    ep.ipAddress = localhost;
    ep.encrypted = true;
    // set the right data.
    auto con = client.connection(ep);
    con.setCertificate(newCertBuf);
    con.setOnConnected([&gotConnect](const EndPoint&) {
        logFatal() << "yay, connect succeeded!";
        gotConnect.ref();
    });
    // Read by the disconnect callback on a network thread and cleared by the
    // test thread before the deliberate disconnect, so keep it atomic.
    QAtomicInt firstConnect(1);
    con.setOnDisconnected([&firstConnect](const EndPoint&) {
        if (firstConnect.loadAcquire())
            QFAIL("Did not expect disconnect");
    });
    con.setOnError([](int connectionId, const boost::system::error_code &err) {
        logFatal() << connectionId << err.message();
        QFAIL("should not error");
    });
    con.send(headerOnlyMessage);
    QTRY_COMPARE(gotIncoming.loadAcquire(), 1);
    // Reconnect the client and see if a new message comes through, which goes towards having
    // a proper handshaking logic.
    firstConnect.storeRelease(0);
    con.disconnect();
    gotIncoming.storeRelease(0);
    QTest::qWait(100);
    con.connect(); // re-activate the connection.
    con.send(headerOnlyMessage);
    QTRY_COMPARE(gotIncoming.loadAcquire(), 1);
}
/*
 * Checks endpoint bookkeeping: an unknown connection id (199) yields an
 * endpoint without hostname, and a created connection's id maps back to
 * the endpoint it was created with.
 */
void TestNWM::basic()
{
    WorkerThreads threads;
    NetworkManager nw(threads.ioContext());

    auto endPoint = nw.endPoint(199);
    QVERIFY(endPoint.hostname.empty());

    endPoint.ipAddress = boost::asio::ip::address_v4::loopback();
    endPoint.announcePort = 1212;

    auto connection = nw.connection(endPoint);
    auto stored = nw.endPoint(connection.connectionId());
    QCOMPARE(stored.announcePort, 1212);
}
// Expands to a main() that runs all TestNWM test functions.
QTEST_MAIN(TestNWM)