// EHS/src/io/socket/Socket.cpp
// (source listing: 1316 lines, 33 KiB, C++)

#include "io/socket/Socket.h"
#include "io/socket/Endpoint.h"
#include "system/System.h"
#include "ehs/Encryption.h"
#if defined(EHS_OS_WINDOWS)
#include <WinSock2.h>
#include <WS2tcpip.h>
#elif defined(EHS_OS_LINUX)
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <unistd.h>
#include <fcntl.h>
#include <cerrno>
#endif
namespace ehc
{
// Version of this socket layer. It is written into the connect handshake
// (Connect) and compared against the remote's version when a connect request
// arrives (Poll); a mismatch rejects the connection.
const Version Socket::ver(1, 0, 0);
// Hash of the built-in "Internal" system that carries connection management.
const UInt_64 Socket::internalSys = Str_8::Hash_64("Internal");
// Operation hashes of the internal system, dispatched on in Poll().
const UInt_64 Socket::connectOp = Str_8::Hash_64("Connect");
const UInt_64 Socket::connectedOp = Str_8::Hash_64("Connected");
const UInt_64 Socket::rejectedOp = Str_8::Hash_64("Rejected");
const UInt_64 Socket::disconnectOp = Str_8::Hash_64("Disconnect");
const UInt_64 Socket::disconnectedOp = Str_8::Hash_64("Disconnected");
const UInt_64 Socket::statusUpdateOp = Str_8::Hash_64("StatusUpdate");
const UInt_64 Socket::pingOp = Str_8::Hash_64("Ping");
const UInt_64 Socket::pongOp = Str_8::Hash_64("Pong");
const UInt_64 Socket::latencyOp = Str_8::Hash_64("Latency");
const UInt_64 Socket::receivedOp = Str_8::Hash_64("Received");
/// Destroys the socket. If it is still initialized, UnInitialize() notifies
/// peers, frees endpoints and systems, and closes the handle. Systems that were
/// registered while the socket was never initialized are freed here as well.
Socket::~Socket()
{
	UnInitialize();

	// UnInitialize() returns early when there is no handle, which would leak any
	// systems added via AddSystem(). After a full UnInitialize() the array is
	// already empty, so this loop is a no-op.
	for (UInt_64 i = 0; i < systems.Size(); ++i)
		delete systems[i];
	systems.Clear();
}
/// Default-constructs an uninitialized IPv4 socket with no id, no app version
/// and an UNKNOWN disposition.
Socket::Socket()
	: hdl(EHS_INVALID_SOCKET), type(IP::V4), port(0), bound(false), appVer(0, 0, 0),
	  disposition(Disposition::UNKNOWN), dropPackets(false), hashId(0), buffer(nullptr), bufferSize(0),
	  // Seed the TSC like the other constructors do. With 0, the first Poll()
	  // would see a delta equal to the whole time since the TSC started.
	  maxEndpoints(0), lastTSC(CPU::GetTSC()), delta(0.0f), maxTimeout(5.0f), resendRate(0.5f), connectedCb(nullptr),
	  activeCb(nullptr), disconnectedCb(nullptr)
{
	AddType("Socket");
}
/// Constructs an uninitialized IPv4 socket.
/// @param ver          Application version, exchanged during the handshake.
/// @param disposition  Role this socket reports to peers.
/// @param id           Identifier of this socket; its Hash_32 becomes hashId.
/// @param maxEndpoints Maximum number of active ENDPOINT peers (0 = no limit).
Socket::Socket(const Version& ver, const Disposition disposition, const Str_8& id, const UInt_64 maxEndpoints)
	: hdl(EHS_INVALID_SOCKET), type(IP::V4), port(0), bound(false), appVer(ver), disposition(disposition),
	  dropPackets(false), id(id), hashId(id.Hash_32()), buffer(nullptr), bufferSize(0),
	  maxEndpoints(maxEndpoints), lastTSC(CPU::GetTSC()), delta(0.0f), maxTimeout(5.0f), resendRate(0.5f),
	  connectedCb(nullptr), activeCb(nullptr), disconnectedCb(nullptr)
{
	// Register the type name, as the default and copy constructors do.
	AddType("Socket");
}
/// Copy-constructs the configuration of @p sock. The copy starts with no
/// handle, no buffer, no systems/endpoints, no callbacks, and unbound.
Socket::Socket(const Socket& sock)
	: BaseObj(sock),
	  hdl(EHS_INVALID_SOCKET),
	  type(sock.type),
	  address(sock.address),
	  port(sock.port),
	  bound(false),
	  appVer(sock.appVer),
	  disposition(sock.disposition),
	  dropPackets(sock.dropPackets),
	  id(sock.id),
	  hashId(sock.hashId),
	  buffer(nullptr),
	  bufferSize(0),
	  maxEndpoints(sock.maxEndpoints),
	  lastTSC(CPU::GetTSC()),
	  delta(0.0f),
	  maxTimeout(sock.maxTimeout),
	  resendRate(sock.resendRate),
	  connectedCb(nullptr),
	  activeCb(nullptr),
	  disconnectedCb(nullptr)
{
	AddType("Socket");
}
/// Copy-assigns the configuration of @p sock: address family, address, port,
/// versions, disposition, id, limits and timing. Like the copy constructor, it
/// does not copy runtime state (handle, buffer, systems, endpoints, callbacks).
/// Resources currently owned by this socket are released first.
Socket& Socket::operator=(const Socket& sock)
{
	if (this == &sock)
		return *this;

	// Release the current handle, buffer, endpoints and systems instead of
	// overwriting (and leaking) them.
	UnInitialize();

	// UnInitialize() skips system cleanup when the socket has no handle.
	for (UInt_64 i = 0; i < systems.Size(); ++i)
		delete systems[i];

	BaseObj::operator=(sock);

	hdl = EHS_INVALID_SOCKET;
	type = sock.type;
	address = sock.address;
	port = sock.port;
	bound = false;
	appVer = sock.appVer;
	disposition = sock.disposition;
	dropPackets = sock.dropPackets;
	id = sock.id;
	hashId = sock.hashId;
	buffer = nullptr;
	bufferSize = 0;
	systems = Array<System*>();
	endpoints = Vector<Endpoint*>();
	maxEndpoints = sock.maxEndpoints;
	lastTSC = CPU::GetTSC(); // matches the copy constructor, so the next Poll() delta is sane
	delta = 0.0f;
	maxTimeout = sock.maxTimeout;
	resendRate = sock.resendRate;
	connectedCb = nullptr;
	activeCb = nullptr;
	disconnectedCb = nullptr;

	return *this;
}
/// Creates the UDP socket handle for the configured address family and
/// allocates the receive buffer sized for that family's maximum UDP payload.
/// Does nothing if already initialized or if the address type is unsupported.
/// Failures are logged; on Windows, WSA is cleaned up again on failure.
void Socket::Initialize()
{
	if (hdl != EHS_INVALID_SOCKET)
		return;

	// Reject unsupported families before WSAStartup(). Otherwise the early
	// return would leave an unmatched WSA reference behind on Windows.
	if (type != IP::V4 && type != IP::V6)
		return;

#if defined(EHS_OS_WINDOWS)
	WSADATA data = {};

	int wsaCode = WSAStartup(MAKEWORD(2, 2), &data);
	if (wsaCode)
	{
		EHS_LOG_INT(LogType::ERR, 0, "WSAStartup failed with the error #" + Str_8::FromNum(wsaCode) + ".");
		return;
	}
#endif

	if (type == IP::V6)
		hdl = socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP);
	else
		hdl = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);

	if (hdl == EHS_INVALID_SOCKET)
	{
		UInt_32 code = 0;

#if defined(EHS_OS_WINDOWS)
		code = WSAGetLastError();
#elif defined(EHS_OS_LINUX)
		code = errno;
#endif

		EHS_LOG_INT(LogType::ERR, 1, "Failed to create socket with error #" + Str_8::FromNum(code) + ".");

#if defined(EHS_OS_WINDOWS)
		if (WSACleanup() == SOCKET_ERROR)
			EHS_LOG_INT(LogType::ERR, 2, "Failed to shutdown WSA with error #" + Str_8::FromNum(WSAGetLastError()) + ".");
#endif

		return;
	}

	if (type == IP::V4)
	{
		buffer = new Byte[EHS_IPV4_UDP_PAYLOAD];
		bufferSize = EHS_IPV4_UDP_PAYLOAD;
	}
	else
	{
		buffer = new Byte[EHS_IPV6_UDP_PAYLOAD];
		bufferSize = EHS_IPV6_UDP_PAYLOAD;
	}
}
/// Tears the socket down, in this order:
/// - frees the receive buffer;
/// - sends a disconnect (empty reason) to every non-PENDING endpoint;
/// - deletes all endpoints and systems;
/// - closes the handle (and, on Windows, releases WSA).
/// Does nothing when the socket is not initialized.
void Socket::UnInitialize()
{
	if (hdl == EHS_INVALID_SOCKET)
		return;

	delete[] buffer;
	buffer = nullptr;
	bufferSize = 0;

	// A zero length prefix: the disconnect handler reads an empty reason string.
	Serializer payload(Endianness::LE);
	payload.Write<UInt_64>(0);

	for (UInt_64 i = 0; i < endpoints.Size(); ++i)
	{
		if (endpoints[i]->GetStatus() != Status::PENDING)
			endpoints[i]->Send(false, true, false, internalSys, disconnectOp, payload);

		delete endpoints[i];
	}
	endpoints.Clear();

	for (UInt_64 i = 0; i < systems.Size(); ++i)
		delete systems[i];
	systems.Clear();

	Int_32 code = 0;

#if defined(EHS_OS_WINDOWS)
	code = closesocket(hdl);
	if (code == SOCKET_ERROR)
		// Winsock errors are reported via WSAGetLastError(), consistent with the rest of this file.
		EHS_LOG_INT(LogType::ERR, 0, "Failed to close socket with error #" + Str_8::FromNum(WSAGetLastError()) + ".");
#elif defined(EHS_OS_LINUX)
	code = close(hdl);
	if (code == -1)
		EHS_LOG_INT(LogType::ERR, 0, "Failed to close socket with error #" + Str_8::FromNum(errno) + ".");
#endif

	hdl = EHS_INVALID_SOCKET;

#if defined(EHS_OS_WINDOWS)
	if (WSACleanup() == SOCKET_ERROR)
		EHS_LOG_INT(LogType::ERR, 1, "Failed to shutdown WSA with error #" + Str_8::FromNum(WSAGetLastError()) + ".");
#endif

	bound = false;
}
/// Binds the socket to a local address and port using the configured address
/// family. An empty @p newAddress binds to the family's wildcard address.
/// Does nothing when the socket is uninitialized or already bound.
/// NOTE(review): address, port and the bound flag are recorded even if the
/// underlying bind() fails; Bind_v4/Bind_v6 only log their errors.
void Socket::Bind(const Str_8& newAddress, const UInt_16 newPort)
{
	const bool canBind = hdl != EHS_INVALID_SOCKET && !bound;
	if (!canBind)
		return;

	switch (type)
	{
	case IP::V6:
		Bind_v6(newAddress, newPort);
		break;
	case IP::V4:
		Bind_v4(newAddress, newPort);
		break;
	default:
		break;
	}

	address = newAddress;
	port = newPort;
	bound = true;
}
void Socket::Connect(const Str_8& address, const UInt_16 port)
{
if (hdl == EHS_INVALID_SOCKET)
return;
Serializer payload(Endianness::LE);
payload.Write(CPU::GetArchitecture());
payload.WriteStr(id);
payload.WriteVersion(ver);
payload.WriteVersion(appVer);
Endpoint* end = new Endpoint(hdl, type, address, port);
end->SetParent(this);
end->Send(false, true, false, "Internal", "Connect", payload);
endpoints.Push(end);
}
bool Socket::Disconnect(const Disposition disposition, const UInt_64 hashId, const Str_8& msg)
{
if (hdl == EHS_INVALID_SOCKET)
return false;
Endpoint* end = GetEndpoint(disposition, hashId);
if (!end)
return false;
Str_8 dcMsg = end->GetId() + " has been disconnected.";
if (msg.Size())
dcMsg += " Reason: " + msg;
EHS_LOG_INT(LogType::INFO, 0, dcMsg);
Serializer<> payload(Endianness::LE);
payload.WriteStr(msg);
end->Send(false, true, false, internalSys, disconnectOp, payload);
return true;
}
bool Socket::Disconnect(const Disposition disposition, const Str_8& id, const Str_8& msg)
{
return Disconnect(disposition, id.Hash_32(), msg);
}
void Socket::Broadcast(const Disposition disposition, const Status status, const bool deltaLocked, const bool encrypted,
const bool ensure, const UInt_64 sysHashId, const UInt_64 opHashId,
const Serializer<>& payload)
{
if (hdl == EHS_INVALID_SOCKET)
return;
for (UInt_64 i = 0; i < endpoints.Size(); ++i)
{
if (endpoints[i]->GetDisposition() != disposition)
continue;
if (endpoints[i]->GetStatus() == status)
endpoints[i]->Send(deltaLocked, encrypted, ensure, sysHashId, opHashId, payload);
}
}
void Socket::Broadcast(const Disposition disposition, const Status status, const bool deltaLocked, const bool encrypted,
const bool ensure, const Str_8& sysId, const Str_8& opId,
const Serializer<>& payload)
{
Broadcast(disposition, status, deltaLocked, encrypted, ensure, sysId.Hash_64(), opId.Hash_64(), payload);
}
void Socket::Poll()
{
if (hdl == EHS_INVALID_SOCKET)
return;
UInt_64 newTSC = CPU::GetTSC();
delta = (float)(newTSC - lastTSC) / (float)CPU::GetTSC_Freq();
lastTSC = newTSC;
Str_8 rAddress;
UInt_16 rPort = 0;
UInt_16 received = 0;
while ((received = Receive(&rAddress, &rPort, buffer, bufferSize)))
{
Serializer<> payload(Endianness::LE, buffer, received);
bool encrypted = payload.Read<bool>();
if (encrypted)
Encryption::Encrypt_64(payload.Size() - payload.GetOffset(), &payload[payload.GetOffset()]);
payload.SetOffset(0);
Header header = payload.Read<Header>();
if (!header.ensure && header.endpointId && header.system == internalSys && header.op == connectOp)
{
Architecture rArch = payload.Read<Architecture>();
Str_8 rId = payload.ReadStr<Char_8, UInt_64>();
Endpoint* end = new Endpoint(hdl, header.disposition, rArch, rId, type, rAddress, rPort);
end->SetStatus(Status::PENDING);
end->SetParent(this);
Serializer sPayload(Endianness::LE);
Version rVer = payload.ReadVersion();
if (rVer != ver)
{
sPayload.WriteStr<Char_8, UInt_64>("Your Event Horizon Socket Layer version " +
Str_8::FromNum(rVer.major) + "." + Str_8::FromNum(rVer.minor) + "." + Str_8::FromNum(rVer.patch) +
" does not match remote endpoint version " +
Str_8::FromNum(ver.major) + "." + Str_8::FromNum(ver.minor) + "." + Str_8::FromNum(ver.patch) +
". Connection rejected.");
end->Send(false, true, false, internalSys, rejectedOp, sPayload);
continue;
}
Version rAppVer = payload.ReadVersion();
if (rAppVer != appVer)
{
sPayload.WriteStr<Char_8, UInt_64>("Your application version " +
Str_8::FromNum(rAppVer.major) + "." + Str_8::FromNum(rAppVer.minor) + "." + Str_8::FromNum(rAppVer.patch) +
" does not match remote endpoint application version " +
Str_8::FromNum(appVer.major) + "." + Str_8::FromNum(appVer.minor) + "." + Str_8::FromNum(appVer.patch) +
". Connection rejected.");
end->Send(false, true, false, internalSys, rejectedOp, sPayload);
continue;
}
if (HasEndpoint(header.disposition, header.endpointId))
{
if (header.disposition == Disposition::SERVICE)
{
sPayload.WriteStr<Char_8, UInt_64>(
"The service, \"" + end->GetId() + "\", is taken. Connection rejected.");
end->Send(false, true, false, internalSys, rejectedOp, sPayload);
continue;
}
else
{
sPayload.WriteStr<Char_8, UInt_64>(
"The username, \"" + end->GetId() + "\", is taken. Connection rejected.");
end->Send(false, true, false, internalSys, rejectedOp, sPayload);
continue;
}
}
if (connectedCb && !connectedCb(this, end))
{
sPayload.WriteStr<Char_8, UInt_64>("Connection rejected.");
end->Send(false, true, false, internalSys, rejectedOp, sPayload);
continue;
}
endpoints.Push(end);
sPayload.Write(CPU::GetArchitecture());
sPayload.WriteStr(id);
UInt_64 active = GetEndpointsCount(Disposition::ENDPOINT, Status::ACTIVE);
if (maxEndpoints && active >= maxEndpoints)
{
end->SetStatus(Status::IN_LOCAL_QUEUE);
UpdateQueue(active);
sPayload.Write(Status::IN_REMOTE_QUEUE);
sPayload.Write(end->GetQueueSlot());
EHS_LOG_INT(LogType::INFO, 1, end->GetId() + " connected and is in queue slot " + end->GetQueueSlot() + ".");
}
else
{
end->SetStatus(Status::ACTIVE);
if (activeCb)
activeCb(this, end);
sPayload.Write(Status::ACTIVE);
sPayload.Write(0);
EHS_LOG_INT(LogType::INFO, 1, end->GetId() + " connected.");
}
end->Send(false, true, false, internalSys, connectedOp, sPayload);
}
else if (!header.ensure && header.endpointId && header.system == internalSys && header.op == connectedOp)
{
Endpoint* end = GetEndpoint(rAddress, rPort);
if (!end || end->GetStatus() != Status::PENDING)
continue;
Architecture arch = payload.Read<Architecture>();
Str_8 id = payload.ReadStr<Char_8, UInt_64>();
*end = Endpoint(hdl, header.disposition, arch, id, type, rAddress, rPort);
end->SetStatus(payload.Read<Status>());
end->SetQueueSlot(payload.Read<UInt_64>());
end->SetParent(this);
if (connectedCb)
connectedCb(this, end);
Str_8 msg = "Successfully connected to " + end->GetId();
if (end->GetStatus() == Status::IN_REMOTE_QUEUE)
msg += " and in queue slot " + Str_8::FromNum(end->GetQueueSlot()) + ".";
else
msg += ".";
EHS_LOG_INT(LogType::INFO, 2, msg);
}
else if (!header.ensure && header.endpointId && header.system == internalSys && header.op == rejectedOp)
{
if (!RemoveEndpoint(rAddress, rPort))
continue;
Str_8 msg = payload.ReadStr<Char_8, UInt_64>();
if (msg.Size())
EHS_LOG_INT(LogType::INFO, 3, msg);
}
else if (!header.ensure && header.endpointId && header.system == internalSys && header.op == disconnectOp)
{
Endpoint* end = GetEndpoint(header.disposition, header.endpointId);
if (!end)
continue;
end->Send(false, true, false, internalSys, disconnectedOp, {});
if (disconnectedCb)
disconnectedCb(this, end);
Str_8 dcMsg;
if (header.disposition == Disposition::SERVICE)
dcMsg = "You have been disconnected from, \"" + end->GetId() + "\".";
else
dcMsg = end->GetId() + " has disconnected.";
Str_8 msg = payload.ReadStr<Char_8, UInt_64>();
if (msg.Size())
dcMsg += " Reason: " + msg;
EHS_LOG_INT(LogType::INFO, 4, dcMsg);
RemoveEndpoint(header.disposition, end->GetHashId());
UpdateQueue();
}
else if (!header.ensure && header.endpointId && header.system == internalSys && header.op == disconnectedOp)
{
Endpoint* end = GetEndpoint(header.disposition, header.endpointId);
if (!end)
continue;
if (disconnectedCb)
disconnectedCb(this, end);
RemoveEndpoint(end);
}
else if (!header.ensure && header.endpointId && header.system == internalSys && header.op == statusUpdateOp)
{
Endpoint* end = GetEndpoint(header.disposition, header.endpointId);
if (!end)
continue;
Status newStatus = payload.Read<Status>();
UInt_64 newSlot = payload.Read<UInt_64>();
if (end->GetStatus() == Status::ACTIVE)
{
if (activeCb)
activeCb(this, end);
EHS_LOG_INT(LogType::INFO, 5, "Your connection status to " + end->GetId() + " has now become active.");
}
else if (end->GetStatus() == Status::IN_REMOTE_QUEUE && newStatus == Status::IN_REMOTE_QUEUE)
{
EHS_LOG_INT(LogType::INFO, 5, "Your queue slot for " + end->GetId() + " is now " + newSlot + ".");
}
end->SetStatus(newStatus);
end->SetQueueSlot(newSlot);
}
else if (!header.ensure && header.endpointId && header.system == internalSys && header.op == pingOp)
{
Endpoint* end = GetEndpoint(header.disposition, header.endpointId);
if (!end)
continue;
end->SetDeltaRate(payload.Read<float>());
end->Pong(delta);
}
else if (!header.ensure && header.endpointId && header.system == internalSys && header.op == pongOp)
{
Endpoint* end = GetEndpoint(header.disposition, header.endpointId);
if (!end)
continue;
end->SetDeltaRate(payload.Read<float>());
end->SendLatency();
}
else if (!header.ensure && header.endpointId && header.system == internalSys && header.op == latencyOp)
{
Endpoint* end = GetEndpoint(header.disposition, header.endpointId);
if (!end)
continue;
end->SetLatency(payload.Read<float>());
}
else if (!header.ensure && header.endpointId && header.system == internalSys && header.op == receivedOp)
{
Endpoint* end = GetEndpoint(header.disposition, header.endpointId);
if (!end)
continue;
UInt_64 msgId = payload.Read<UInt_64>();
UInt_64 fragment = payload.Read<UInt_64>();
end->RemoveInsurance(msgId, fragment);
}
else if (header.endpointId)
{
Endpoint* end = GetEndpoint(header.disposition, header.endpointId);
if (!end)
continue;
if (dropPackets && !header.ensure && header.id < end->GetNextRecvId())
{
EHS_LOG_INT(LogType::INFO, 6, "Old packet intentionally dropped.");
continue;
}
if (header.ensure)
{
Serializer sPayload(Endianness::LE);
sPayload.Write(header.id);
sPayload.Write(header.fragment);
end->Send(false, true, false, internalSys, receivedOp, sPayload);
}
end->AddReceived(
header,
Serializer<>(Endianness::LE, &payload[payload.GetOffset()], payload.Size() - payload.GetOffset())
);
}
else
{
EHS_LOG_INT(LogType::INFO, 7, "Corrupted packet.");
}
}
PollEndpoints(endpoints);
}
/// True once Initialize() has produced a valid socket handle.
bool Socket::IsInitialized() const
{
	const bool hasHandle = !(hdl == EHS_INVALID_SOCKET);
	return hasHandle;
}
/// Selects the address family used by Initialize(). Ignored while a socket handle exists.
void Socket::SetAddressType(const IP newType)
{
	if (hdl == EHS_INVALID_SOCKET)
		type = newType;
}
/// Returns the configured address family.
IP Socket::GetAddressType() const
{
	return this->type;
}
/// Returns the address recorded by Bind() (empty when not set).
Str_8 Socket::GetAddress() const
{
	return this->address;
}
/// Returns the port recorded by Bind().
UInt_16 Socket::GetPort() const
{
	return this->port;
}
bool Socket::IsBound() const
{
return bound;
}
/// Returns the socket-layer protocol version.
Version Socket::GetVersion() const
{
	return Socket::ver;
}
/// Returns the application version exchanged during the handshake.
Version Socket::GetAppVersion() const
{
	return this->appVer;
}
/// Returns this socket's disposition.
Disposition Socket::GetDisposition() const
{
	return this->disposition;
}
void Socket::EnableDropPackets(const bool enable)
{
dropPackets = enable;
}
bool Socket::IsDropPacketsEnabled() const
{
return dropPackets;
}
/// Returns this socket's textual id.
Str_8 Socket::GetId() const
{
	return this->id;
}
/// Returns the hash of this socket's id.
UInt_64 Socket::GetHashId() const
{
	return this->hashId;
}
bool Socket::HasSystem(const UInt_64 hashId) const
{
if (internalSys == hashId)
return true;
for (UInt_64 i = 0; i < systems.Size(); ++i)
if (systems[i]->GetHashId() == hashId)
return true;
return false;
}
bool Socket::HasSystem(const Str_8& id) const
{
return HasSystem(id.Hash_64());
}
/// Registers @p sys so that received packets addressed to it are dispatched to
/// it. The socket deletes registered systems in UnInitialize().
/// @return false if @p sys is null or a system with the same hash (or the
///         internal system's hash) is already present; ownership then stays with the caller.
bool Socket::AddSystem(System* sys)
{
	if (!sys)
		return false;

	if (HasSystem(sys->GetHashId()))
		return false;

	systems.Push(sys);

	return true;
}
/// Returns the registered system with the given hash, or nullptr if there is none.
System* Socket::GetSystem(const UInt_64 hashId)
{
	for (UInt_64 s = 0; s < systems.Size(); ++s)
	{
		System* const candidate = systems[s];
		if (candidate->GetHashId() == hashId)
			return candidate;
	}

	return nullptr;
}
/// Convenience overload: looks the system up by name.
/// System ids are hashed with Hash_64 everywhere else (internalSys, HasSystem(Str_8),
/// Broadcast(Str_8...)); Hash_32 here could never match a registered system.
System* Socket::GetSystem(const Str_8& id)
{
	return GetSystem(id.Hash_64());
}
bool Socket::HasEndpoint(const Disposition disposition, const Status status, const UInt_64 hashId) const
{
for (UInt_64 i = 0; i < endpoints.Size(); ++i)
{
if (endpoints[i]->GetDisposition() != disposition)
continue;
if (endpoints[i]->GetStatus() != status)
continue;
if (endpoints[i]->GetHashId() == hashId)
return true;
}
return false;
}
bool Socket::HasEndpoint(const Disposition disposition, const Status status, const Str_8& id) const
{
return HasEndpoint(disposition, status, id.Hash_32());
}
bool Socket::HasEndpoint(const Disposition disposition, const UInt_64 hashId) const
{
for (UInt_64 i = 0; i < endpoints.Size(); ++i)
{
if (endpoints[i]->GetDisposition() != disposition)
continue;
if (endpoints[i]->GetHashId() == hashId)
return true;
}
return false;
}
bool Socket::HasEndpoint(const Disposition disposition, const Str_8& id) const
{
return HasEndpoint(disposition, id.Hash_64());
}
bool Socket::HasEndpoint(const Str_8& address, const UInt_16 port) const
{
for (UInt_64 i = 0; i < endpoints.Size(); ++i)
if (endpoints[i]->GetAddress() == address && endpoints[i]->GetPort() == port)
return true;
return false;
}
/// Returns the first endpoint matching disposition, status and hash id, or nullptr.
Endpoint* Socket::GetEndpoint(const Disposition disposition, const Status status, const UInt_64 hashId)
{
	for (UInt_64 e = 0; e < endpoints.Size(); ++e)
	{
		Endpoint* const candidate = endpoints[e];
		if (candidate->GetDisposition() == disposition && candidate->GetStatus() == status && candidate->GetHashId() == hashId)
			return candidate;
	}

	return nullptr;
}
/// Convenience overload: the endpoint is identified by its textual id (Hash_32).
Endpoint* Socket::GetEndpoint(const Disposition disposition, const Status status, const Str_8& id)
{
	const UInt_64 endpointHash = id.Hash_32();
	return GetEndpoint(disposition, status, endpointHash);
}
/// Returns the first endpoint matching disposition and hash id, or nullptr.
Endpoint* Socket::GetEndpoint(const Disposition disposition, const UInt_64 hashId)
{
	for (UInt_64 e = 0; e < endpoints.Size(); ++e)
	{
		Endpoint* const candidate = endpoints[e];
		if (candidate->GetDisposition() == disposition && candidate->GetHashId() == hashId)
			return candidate;
	}

	return nullptr;
}
/// Convenience overload: the endpoint is identified by its textual id (Hash_32).
Endpoint* Socket::GetEndpoint(const Disposition disposition, const Str_8& id)
{
	const UInt_64 endpointHash = id.Hash_32();
	return GetEndpoint(disposition, endpointHash);
}
/// Returns the first endpoint with this remote address and port, or nullptr.
Endpoint* Socket::GetEndpoint(const Str_8& address, const UInt_16 port)
{
	for (UInt_64 e = 0; e < endpoints.Size(); ++e)
	{
		Endpoint* const candidate = endpoints[e];
		if (candidate->GetAddress() == address && candidate->GetPort() == port)
			return candidate;
	}

	return nullptr;
}
/// Collects every endpoint with the given disposition and status.
Array<Endpoint*> Socket::GetEndpoints(const Disposition disposition, const Status status)
{
	// Size for the worst case, fill the matches, then shrink to fit.
	Array<Endpoint*> matches(endpoints.Size());
	UInt_64 found = 0;

	for (UInt_64 e = 0; e < endpoints.Size(); ++e)
	{
		Endpoint* const candidate = endpoints[e];
		if (candidate->GetDisposition() == disposition && candidate->GetStatus() == status)
			matches[found++] = candidate;
	}

	matches.Resize(found);
	return matches;
}
/// Collects every endpoint with the given disposition, whatever its status.
Array<Endpoint*> Socket::GetEndpoints(const Disposition disposition)
{
	Array<Endpoint*> matches(endpoints.Size());
	UInt_64 found = 0;

	for (UInt_64 e = 0; e < endpoints.Size(); ++e)
	{
		Endpoint* const candidate = endpoints[e];
		if (candidate->GetDisposition() == disposition)
			matches[found++] = candidate;
	}

	matches.Resize(found);
	return matches;
}
/// Counts endpoints with the given disposition and status.
UInt_64 Socket::GetEndpointsCount(const Disposition disposition, const Status status)
{
	UInt_64 total = 0;

	for (UInt_64 e = 0; e < endpoints.Size(); ++e)
	{
		Endpoint* const candidate = endpoints[e];
		if (candidate->GetDisposition() == disposition && candidate->GetStatus() == status)
			++total;
	}

	return total;
}
/// Counts endpoints with the given disposition, whatever their status.
UInt_64 Socket::GetEndpointsCount(const Disposition disposition)
{
	UInt_64 total = 0;

	for (UInt_64 e = 0; e < endpoints.Size(); ++e)
	{
		if (endpoints[e]->GetDisposition() == disposition)
			++total;
	}

	return total;
}
/// Returns the active-endpoint limit (0 = unlimited).
UInt_64 Socket::GetMaxEndpoints() const
{
	return this->maxEndpoints;
}
/// Switches the socket between blocking and non-blocking mode.
/// Logs an error if the socket is not initialized or the platform call fails.
void Socket::SetBlocking(const bool blocking)
{
	if (hdl == EHS_INVALID_SOCKET)
	{
		EHS_LOG_INT(LogType::ERR, 0, "Attempted to toggle blocking while socket is not initialized.");
		return;
	}

#if defined(EHS_OS_WINDOWS)
	// FIONBIO enables non-blocking mode for a non-zero argument.
	u_long r = (u_long)!blocking;

	// ioctlsocket() returns SOCKET_ERROR on failure; the cause comes from WSAGetLastError().
	if (ioctlsocket(hdl, FIONBIO, &r) == SOCKET_ERROR)
		EHS_LOG_INT(LogType::ERR, 1, "Failed to toggle non-blocking mode with error #" + Str_8::FromNum(WSAGetLastError()) + ".");
#elif defined(EHS_OS_LINUX)
	// Read-modify-write the file status flags: only O_NONBLOCK is toggled
	// according to `blocking`, and the other flags are preserved.
	int flags = fcntl(hdl, F_GETFL, 0);
	if (flags == -1)
	{
		EHS_LOG_INT(LogType::ERR, 1, "Failed to toggle non-blocking mode with error #" + Str_8::FromNum(errno) + ".");
		return;
	}

	flags = blocking ? (flags & ~O_NONBLOCK) : (flags | O_NONBLOCK);

	if (fcntl(hdl, F_SETFL, flags) == -1)
		EHS_LOG_INT(LogType::ERR, 1, "Failed to toggle non-blocking mode with error #" + Str_8::FromNum(errno) + ".");
#endif
}
/// Reports whether the socket is in blocking mode.
bool Socket::IsBlocking() const
{
#if defined(EHS_OS_WINDOWS)
	// NOTE(review): FIONREAD reports the number of bytes pending, not the
	// blocking mode, and Winsock offers no query for FIONBIO. A correct
	// answer needs the mode tracked in a member set by SetBlocking().
	u_long r = 0;
	if (ioctlsocket(hdl, FIONREAD, &r) == SOCKET_ERROR)
		EHS_LOG_INT(LogType::ERR, 0, "Failed to retrieve socket info.");

	return (bool)r;
#elif defined(EHS_OS_LINUX)
	// The socket blocks exactly when O_NONBLOCK is clear in the status flags.
	int flags = fcntl(hdl, F_GETFL, 0);
	if (flags == -1)
	{
		EHS_LOG_INT(LogType::ERR, 0, "Failed to retrieve socket info.");
		return true;
	}

	return !(flags & O_NONBLOCK);
#else
	return true;
#endif
}
void Socket::SetMaxTimeout(const float seconds)
{
maxTimeout = seconds;
}
float Socket::GetMaxTimeout() const
{
return maxTimeout;
}
void Socket::SetResendRate(const float seconds)
{
resendRate = seconds;
}
float Socket::GetResendRate() const
{
return resendRate;
}
/// Sets the connection callback. On incoming connects, returning false rejects the peer.
void Socket::SetConnectedCb(bool (*newCb)(Socket*, Endpoint*))
{
	this->connectedCb = newCb;
}
/// Sets the callback invoked when an endpoint becomes active.
void Socket::SetActiveCb(void (*newCb)(Socket*, Endpoint*))
{
	this->activeCb = newCb;
}
/// Sets the callback invoked when an endpoint disconnects or times out.
void Socket::SetDisconnectedCb(void (*newCb)(Socket*, Endpoint*))
{
	this->disconnectedCb = newCb;
}
/// Re-evaluates locally queued endpoints (Status::IN_LOCAL_QUEUE) in storage
/// order:
/// - while fewer than maxEndpoints are active, a queued endpoint is promoted to
///   ACTIVE, the peer is notified, and activeCb is invoked;
/// - the remaining endpoints get consecutive queue slots from 0, and the peer
///   is notified whenever its slot changes.
/// @param active Number of endpoints currently counted as active.
void Socket::UpdateQueue(UInt_64 active)
{
	UInt_64 slot = 0;

	for (UInt_64 i = 0; i < endpoints.Size(); ++i)
	{
		Endpoint* const end = endpoints[i];
		if (end->GetStatus() != Status::IN_LOCAL_QUEUE)
			continue;

		if (active < maxEndpoints)
		{
			end->SetStatus(Status::ACTIVE);
			end->SetQueueSlot(0);

			Serializer payload(Endianness::LE);
			payload.Write(Status::ACTIVE);
			// The receiver reads the slot as UInt_64; a bare literal 0 would be
			// serialized as a 4-byte int.
			payload.Write<UInt_64>(0);

			end->Send(false, true, false, internalSys, statusUpdateOp, payload);

			if (activeCb)
				activeCb(this, end);

			++active;
		}
		else if (end->GetQueueSlot() != slot)
		{
			Serializer payload(Endianness::LE);
			payload.Write(Status::IN_REMOTE_QUEUE);
			payload.Write(slot);

			end->Send(false, true, false, internalSys, statusUpdateOp, payload);

			end->SetQueueSlot(slot++);
		}
		else
		{
			++slot;
		}
	}
}
void Socket::UpdateQueue()
{
UpdateQueue(GetEndpointsCount(Disposition::ENDPOINT, Status::ACTIVE));
}
bool Socket::RemoveEndpoint(const Disposition disposition, const UInt_64 hashId)
{
for (UInt_64 i = 0; i < endpoints.Size(); ++i)
{
if (endpoints[i]->GetDisposition() != disposition)
continue;
if (endpoints[i]->GetHashId() == hashId)
{
delete endpoints[i];
if (i != endpoints.Size() - 1)
endpoints.Swap(i, endpoints.End());
endpoints.Pop();
return true;
}
}
return false;
}
bool Socket::RemoveEndpoint(const Str_8& address, const UInt_16 port)
{
for (UInt_64 i = 0; i < endpoints.Size(); ++i)
{
if (endpoints[i]->GetAddress() == address && endpoints[i]->GetPort() == port)
{
delete endpoints[i];
if (i != endpoints.Size() - 1)
endpoints.Swap(i, endpoints.End());
endpoints.Pop();
return true;
}
}
return false;
}
bool Socket::RemoveEndpoint(const Endpoint* const end)
{
for (UInt_64 i = 0; i < endpoints.Size(); ++i)
{
if (endpoints[i] == end)
{
delete endpoints[i];
if (i != endpoints.Size() - 1)
endpoints.Swap(i, endpoints.End());
endpoints.Pop();
return true;
}
}
return false;
}
/// Advances every endpoint by the current delta.
/// - Endpoints whose timeout reaches maxTimeout are removed. PENDING ones are
///   logged as failed connects; others trigger disconnectedCb and a queue update.
/// - For live endpoints, every fully reassembled packet is dispatched to its
///   system's Execute() and then removed from the received list.
void Socket::PollEndpoints(Vector<Endpoint*>& endpoints)
{
	UInt_64 i = 0;
	while (i < endpoints.Size())
	{
		endpoints[i]->Poll(delta);

		if (endpoints[i]->GetStatus() == Status::PENDING)
		{
			if (endpoints[i]->GetTimeout() >= maxTimeout)
			{
				EHS_LOG_INT(LogType::INFO, 0, "Failed to connect to, \"" + endpoints[i]->GetAddress() + ":" + Str_8::FromNum(endpoints[i]->GetPort()) + "\".");

				delete endpoints[i];

				if (i != endpoints.Size() - 1)
					endpoints.Swap(i, endpoints.End());

				endpoints.Pop();

				continue;
			}
		}
		else
		{
			if (endpoints[i]->GetTimeout() >= maxTimeout)
			{
				EHS_LOG_INT(LogType::INFO, 6, endpoints[i]->GetId() + " timed out.");

				if (disconnectedCb)
					disconnectedCb(this, endpoints[i]);

				delete endpoints[i];

				if (i != endpoints.Size() - 1)
					endpoints.Swap(i, endpoints.End());

				endpoints.Pop();

				UpdateQueue();

				continue;
			}

			Vector<Fragments>* frags = endpoints[i]->GetReceived();

			UInt_64 f = 0;
			while (f < frags->Size())
			{
				if (!(*frags)[f].IsComplete())
				{
					++f;
					continue;
				}

				Packet packet = (*frags)[f].Combine();

				System* sys = GetSystem(packet.header.system);
				if (sys)
					sys->Execute(this, endpoints[i], packet.header.op, packet.payload);
				else
					EHS_LOG_INT(LogType::INFO, 1, "Dropped a packet addressed to an unknown system.");

				// A complete packet is consumed whether or not a system handled it.
				// Leaving unhandled ones in place would keep them (and rescan them) forever.
				frags->Swap(f, frags->End());
				frags->Pop();
			}
		}

		++i;
	}
}
void Socket::Bind_v6(const Str_8& address, const UInt_16 port)
{
Int_32 code = 0;
sockaddr_in6 result = {};
result.sin6_family = AF_INET6;
result.sin6_port = htons(port);
if (address.Size())
{
Int_32 code = inet_pton(AF_INET6, address, &result.sin6_addr);
if (!code)
{
EHS_LOG_INT(LogType::ERR, 0, "The given address, \"" + address + "\" is not valid.");
return;
}
else if (code == -1)
{
Int_32 dCode = 0;
#if defined(EHS_OS_WINDOWS)
dCode = WSAGetLastError();
#elif defined(EHS_OS_LINUX)
dCode = errno;
#endif
EHS_LOG_INT(LogType::ERR, 1, "Failed to convert address with error #" + Str_8::FromNum(dCode) + ".");
return;
}
}
else
{
result.sin6_addr = in6addr_any;
}
code = bind(hdl, (sockaddr*)&result, sizeof(sockaddr_in6));
#if defined(EHS_OS_WINDOWS)
if (code == SOCKET_ERROR)
{
EHS_LOG_INT(LogType::ERR, 2, "Failed to bind socket with error #" + Str_8::FromNum(WSAGetLastError()) + ".");
return;
}
#elif defined(EHS_OS_LINUX)
if (code == -1)
{
EHS_LOG_INT(LogType::ERR, 2, "Failed to bind socket with error #" + Str_8::FromNum(errno) + ".");
return;
}
#endif
}
/// Binds the IPv4 handle to @p address (textual form) and @p port.
/// An empty address binds to INADDR_ANY. Failures are logged only.
void Socket::Bind_v4(const Str_8& address, const UInt_16 port)
{
	Int_32 code = 0;

	sockaddr_in result = {};
	result.sin_family = AF_INET;
	result.sin_port = htons(port);

	if (address.Size())
	{
		// inet_pton returns 1 on success, 0 for an unparsable string, -1 on error.
		code = inet_pton(AF_INET, address, &result.sin_addr);
		if (!code)
		{
			EHS_LOG_INT(LogType::ERR, 0, "The given address, \"" + address + "\" is not valid.");
			return;
		}
		else if (code == -1)
		{
			Int_32 dCode = 0;

#if defined(EHS_OS_WINDOWS)
			dCode = WSAGetLastError();
#elif defined(EHS_OS_LINUX)
			dCode = errno;
#endif

			EHS_LOG_INT(LogType::ERR, 1, "Failed to convert address with error #" + Str_8::FromNum(dCode) + ".");
			return;
		}
	}
	else
	{
		// `s_addr` is portable: POSIX in_addr has it as a field, and Winsock
		// defines it as a macro for S_un.S_addr. The Windows-only S_un member
		// does not exist on Linux.
		result.sin_addr.s_addr = htonl(INADDR_ANY);
	}

	code = bind(hdl, (sockaddr*)&result, sizeof(sockaddr_in));

#if defined(EHS_OS_WINDOWS)
	if (code == SOCKET_ERROR)
	{
		EHS_LOG_INT(LogType::ERR, 2, "Failed to bind socket with error #" + Str_8::FromNum(WSAGetLastError()) + ".");
		return;
	}
#elif defined(EHS_OS_LINUX)
	if (code == -1)
	{
		EHS_LOG_INT(LogType::ERR, 2, "Failed to bind socket with error #" + Str_8::FromNum(errno) + ".");
		return;
	}
#endif
}
/// Receives one datagram into @p data (capacity @p size) and reports the
/// sender's textual address and port through @p address / @p port.
/// @return Number of bytes received. 0 means nothing was available, the call
///         failed, or the socket is uninitialized. Fatal errors call UnInitialize().
UInt_16 Socket::Receive(Str_8* address, UInt_16* port, Byte* const data, const UInt_16 size)
{
	if (hdl == EHS_INVALID_SOCKET)
	{
		EHS_LOG_INT(LogType::ERR, 0, "Attempted to receive while socket is not initialized.");
		return 0;
	}

	if (type == IP::V4 && size > EHS_IPV4_UDP_PAYLOAD)
	{
		EHS_LOG_INT(LogType::ERR, 1, "Attempted to receive with a buffer size of, \"" + Str_8::FromNum(size)
			+ "\", that exceeds, \"" + Str_8::FromNum(EHS_IPV4_UDP_PAYLOAD) + "\".");
	}

	// sockaddr_in6 is large enough to hold either family's address.
	sockaddr_in6 remote = {};
	UInt_32 addrLen = type == IP::V6 ? sizeof(sockaddr_in6) : sizeof(sockaddr_in);
	SInt_64 received = 0;

#if defined(EHS_OS_WINDOWS)
	received = recvfrom(hdl, (char*)data, (int)size, 0, (sockaddr*)&remote, (int*)&addrLen);
	if (received == SOCKET_ERROR)
	{
		int code = WSAGetLastError();
		if (code == WSAEMSGSIZE)
		{
			UnInitialize();

			EHS_LOG_INT(LogType::ERR, 2, "The buffer size, \"" + Str_8::FromNum(size) + "\" is too small.");
		}
		else if (code != WSAECONNRESET && code != WSAEWOULDBLOCK)
		{
			UnInitialize();

			EHS_LOG_INT(LogType::ERR, 3, "Failed to receive with error #" + Str_8::FromNum(code) + ".");
		}

		return 0;
	}
#elif defined(EHS_OS_LINUX)
	received = recvfrom(hdl, data, size, 0, (sockaddr*)&remote, &addrLen);
	if (received == -1)
	{
		int code = errno;
		// These cases are not fatal: no pending datagram (EAGAIN/EWOULDBLOCK,
		// which may differ on some systems), a signal interrupting the call
		// (EINTR), or a reset notification. Anything else tears the socket down.
		if (code != ECONNRESET && code != EAGAIN && code != EWOULDBLOCK && code != EINTR)
		{
			UnInitialize();

			EHS_LOG_INT(LogType::ERR, 2, "Failed to receive with error #" + Str_8::FromNum(code) + ".");
		}

		return 0;
	}
#endif

	// NOTE(review): if address conversion fails below, the byte count is still
	// returned while *address / *port keep their previous values.
	if (addrLen == sizeof(sockaddr_in6))
	{
		char tmpAddr[INET6_ADDRSTRLEN];

		if (!inet_ntop(remote.sin6_family, &remote.sin6_addr, tmpAddr, INET6_ADDRSTRLEN))
		{
			Int_32 code = 0;

#if defined(EHS_OS_WINDOWS)
			code = WSAGetLastError();
#elif defined(EHS_OS_LINUX)
			code = errno;
#endif

			EHS_LOG_INT(LogType::ERR, 2, "Failed to convert IPv6 address with error #" + Str_8::FromNum(code) + ".");

			return (UInt_16)received;
		}

		*address = tmpAddr;
		*port = ntohs(remote.sin6_port);
	}
	else if (addrLen == sizeof(sockaddr_in))
	{
		char tmpAddr[INET_ADDRSTRLEN];

		if (!inet_ntop(((sockaddr_in*)&remote)->sin_family, &((sockaddr_in*)&remote)->sin_addr, tmpAddr, INET_ADDRSTRLEN))
		{
			Int_32 code = 0;

#if defined(EHS_OS_WINDOWS)
			code = WSAGetLastError();
#elif defined(EHS_OS_LINUX)
			code = errno;
#endif

			EHS_LOG_INT(LogType::ERR, 2, "Failed to convert IPv4 address with error #" + Str_8::FromNum(code) + ".");

			return (UInt_16)received;
		}

		*address = tmpAddr;
		*port = ntohs(((sockaddr_in*)&remote)->sin_port);
	}

	return (UInt_16)received;
}
}