#include "ehs/io/socket/ehc/NetServerCh.h"
#include "ehs/io/socket/ehc/NetEnd.h"
#include "ehs/io/socket/ehc/NetSys.h"
#include "ehs/io/socket/EHC.h"
#include "ehs/PRNG.h"

namespace ehs
{
    /// Sends a disconnect to every tracked endpoint and frees it (see Shutdown()).
    NetServerCh::~NetServerCh()
    {
        Shutdown();
    }

    /// Creates a channel with no endpoint limit (maxEndpoints == 0).
    NetServerCh::NetServerCh()
        : maxEndpoints(0)
    {
    }

    /// @param name Channel name, forwarded to NetChannel.
    /// @param version Channel version, forwarded to NetChannel.
    /// @param maxEndpoints Maximum number of ACTIVE endpoints; 0 disables the limit
    ///        (Process() only queues when maxEndpoints is non-zero).
    NetServerCh::NetServerCh(Str_8 name, const Version &version, const UInt_64 maxEndpoints)
        : NetChannel((Str_8 &&)name, version), maxEndpoints(maxEndpoints)
    {
    }

    /// Takes over the source's endpoint list and re-parents every endpoint to this channel.
    NetServerCh::NetServerCh(NetServerCh &&server) noexcept
        : NetChannel((NetChannel &&)server), endpoints((Vector &&)server.endpoints),
        maxEndpoints(server.maxEndpoints)
    {
        // Leave the source in the same state move assignment leaves it in.
        server.maxEndpoints = 0;

        for (UInt_64 i = 0; i < endpoints.Size(); i++)
            endpoints[i]->owner = this;
    }

    /// Copies the base channel state and the endpoint limit; endpoints themselves are not copied.
    NetServerCh::NetServerCh(const NetServerCh &server)
        : NetChannel(server), maxEndpoints(server.maxEndpoints)
    {
    }

    /// Shuts down the current endpoints, then takes over the source's base state,
    /// endpoint list and limit.
    NetServerCh &NetServerCh::operator=(NetServerCh &&server) noexcept
    {
        if (this == &server)
            return *this;

        // Notify and free our own endpoints before the base channel state is replaced,
        // so the disconnects go out under the state those endpoints were created with.
        Shutdown();

        NetChannel::operator=((NetChannel &&)server);

        endpoints = (Vector &&)server.endpoints;
        for (UInt_64 i = 0; i < endpoints.Size(); i++)
            endpoints[i]->owner = this;

        maxEndpoints = server.maxEndpoints;
        server.maxEndpoints = 0;

        return *this;
    }

    /// Shuts down the current endpoints, then copies the base state and the endpoint
    /// limit. The source's endpoints are not copied; this channel ends up with none.
    NetServerCh &NetServerCh::operator=(const NetServerCh &server)
    {
        if (this == &server)
            return *this;

        // Same ordering rationale as move assignment: tear down first, then adopt.
        Shutdown();

        NetChannel::operator=(server);

        endpoints = {};
        maxEndpoints = server.maxEndpoints;

        return *this;
    }

    /// Called by Process() for a connection request. Returning false rejects the
    /// endpoint; this default accepts every request.
    bool NetServerCh::OnEndpointConnect(NetEnd *endpoint, Serializer payload)
    {
        return true;
    }

    /// Called by Process() after a connection was accepted; the returned data is
    /// appended to the "connected" reply. This default returns an empty serializer.
    Serializer NetServerCh::OnEndpointAccepted(NetEnd *endpoint)
    {
        return {};
    }

    /// Called by Process() when an endpoint requests a disconnect, before it is removed.
    void NetServerCh::OnEndpointDisconnect(NetEnd *endpoint, Serializer payload)
    {
    }

    /// Called by Poll() right before a timed-out endpoint is deleted.
    void NetServerCh::OnEndpointTimeout(NetEnd *endpoint)
    {
    }

    /// Called whenever an endpoint becomes ACTIVE (directly on connect or from the queue).
    void NetServerCh::OnEndpointActive(NetEnd *endpoint)
    {
    }

    /// Called by Shutdown(); the returned data is sent as the disconnect payload.
    Serializer NetServerCh::OnShutdown()
    {
        return {};
    }

    // NOTE(review): this is a free function ehs::OnShutdown(), not a member of
    // NetServerCh. It duplicates NetServerCh::OnShutdown() and nothing in this file
    // calls it (Shutdown() resolves to the member). Looks like a leftover; kept so
    // the set of exported symbols is unchanged. Confirm before removing.
    Serializer OnShutdown()
    {
        return {};
    }

    /// Sends the payload to every endpoint whose status equals endStatus.
    /// Does nothing when the owner's UDP socket is not valid.
    void NetServerCh::Broadcast(const NetStatus endStatus, const bool deltaLocked, const UInt_64 encHashId,
        const bool ensure, const UInt_64 sysHashId, const UInt_64 opHashId, const Serializer &payload)
    {
        if (!GetOwner()->udp.IsValid())
            return;

        for (UInt_64 i = 0; i < endpoints.Size(); ++i)
            if (endpoints[i]->GetStatus() == endStatus)
                endpoints[i]->Send(deltaLocked, encHashId, ensure, sysHashId, opHashId, payload);
    }

    /// String-id convenience overload; hashes the ids with Hash_64() and forwards.
    void NetServerCh::Broadcast(const NetStatus endStatus, const bool deltaLocked, const Str_8 &encId,
        const bool ensure, const Str_8 &sysId, const Str_8 &opId, const Serializer &payload)
    {
        Broadcast(endStatus, deltaLocked, encId.Hash_64(), ensure, sysId.Hash_64(), opId.Hash_64(), payload);
    }

    // The HasEndpoint overloads below answer "does the matching GetEndpoint overload
    // find anything?" and share its lookup instead of repeating the scan.

    /// True when an endpoint with the given status and 64-byte token is tracked.
    bool NetServerCh::HasEndpoint(const NetStatus endStatus, const Char_8 token[64]) const
    {
        return GetEndpoint(endStatus, token) != nullptr;
    }

    /// True when an endpoint with the given status and hashed id is tracked.
    bool NetServerCh::HasEndpoint(const NetStatus endStatus, const UInt_64 hashName) const
    {
        return GetEndpoint(endStatus, hashName) != nullptr;
    }

    /// True when an endpoint with the given status and id (hashed via Hash_64()) is tracked.
    bool NetServerCh::HasEndpoint(const NetStatus endStatus, const Str_8 &id) const
    {
        return HasEndpoint(endStatus, id.Hash_64());
    }

    /// True when an endpoint with the given 64-byte token is tracked, in any status.
    bool NetServerCh::HasEndpoint(const Char_8 token[64]) const
    {
        return GetEndpoint(token) != nullptr;
    }

    /// True when an endpoint with the given hashed id is tracked, in any status.
    bool NetServerCh::HasEndpoint(const UInt_64 hashName) const
    {
        return GetEndpoint(hashName) != nullptr;
    }

    /// True when an endpoint with the given id (hashed via Hash_64()) is tracked.
    bool NetServerCh::HasEndpoint(const Str_8 &id) const
    {
        return HasEndpoint(id.Hash_64());
    }

    /// True when an endpoint with the same address and port is tracked.
    bool NetServerCh::HasEndpoint(const Endpoint &endpoint) const
    {
        return GetEndpoint(endpoint) != nullptr;
    }

    /// @return The first endpoint with the given status whose token matches, or nullptr.
    NetEnd* NetServerCh::GetEndpoint(const NetStatus endStatus, const Char_8 token[64]) const
    {
        for (UInt_64 i = 0; i < endpoints.Size(); ++i)
        {
            if (endpoints[i]->GetStatus() != endStatus)
                continue;

            if (Util::Compare(endpoints[i]->token, token, 64))
                return endpoints[i];
        }

        return nullptr;
    }

    /// @return The first endpoint with the given status whose id equals hashName, or nullptr.
    NetEnd *NetServerCh::GetEndpoint(const NetStatus endStatus, const UInt_64 hashName) const
    {
        for (UInt_64 i = 0; i < endpoints.Size(); ++i)
        {
            if (endpoints[i]->GetStatus() != endStatus)
                continue;

            if (endpoints[i]->GetId() == hashName)
                return endpoints[i];
        }

        return nullptr;
    }

    /// String-id convenience overload; hashes the id with Hash_64() and forwards.
    NetEnd *NetServerCh::GetEndpoint(const NetStatus endStatus, const Str_8 &id) const
    {
        return GetEndpoint(endStatus, id.Hash_64());
    }

    /// @return The first endpoint whose 64-byte token matches, or nullptr.
    NetEnd *NetServerCh::GetEndpoint(const Char_8 token[64]) const
    {
        for (UInt_64 i = 0; i < endpoints.Size(); ++i)
            if (Util::Compare(endpoints[i]->token, token, 64))
                return endpoints[i];

        return nullptr;
    }

    /// @return The first endpoint whose id equals hashName, or nullptr.
    NetEnd *NetServerCh::GetEndpoint(const UInt_64 hashName) const
    {
        for (UInt_64 i = 0; i < endpoints.Size(); ++i)
            if (endpoints[i]->GetId() == hashName)
                return endpoints[i];

        return nullptr;
    }

    /// String-id convenience overload; hashes the id with Hash_64() and forwards.
    NetEnd *NetServerCh::GetEndpoint(const Str_8 &id) const
    {
        return GetEndpoint(id.Hash_64());
    }

    /// @return The first endpoint with the same address and port, or nullptr.
    NetEnd *NetServerCh::GetEndpoint(const Endpoint &endpoint) const
    {
        for (UInt_64 i = 0; i < endpoints.Size(); ++i)
            if (endpoints[i]->GetEndpoint().address == endpoint.address && endpoints[i]->GetEndpoint().port == endpoint.port)
                return endpoints[i];

        return nullptr;
    }

    /// @return All endpoints whose status equals endStatus, in list order.
    Array NetServerCh::GetEndpoints(const NetStatus endStatus)
    {
        // Size for the worst case, fill from the front, then shrink to what was used.
        Array result(endpoints.Size());
        UInt_64 count = 0;

        for (UInt_64 i = 0; i < endpoints.Size(); ++i)
            if (endpoints[i]->GetStatus() == endStatus)
                result[count++] = endpoints[i];

        result.Resize(count);

        return result;
    }

    /// @return The number of endpoints whose status equals endStatus.
    UInt_64 NetServerCh::GetEndpointsCount(const NetStatus endStatus)
    {
        UInt_64 count = 0;

        for (UInt_64 i = 0; i < endpoints.Size(); ++i)
            if (endpoints[i]->GetStatus() == endStatus)
                ++count;

        return count;
    }

    /// @return The configured ACTIVE endpoint limit (0 means no limit is enforced).
    UInt_64 NetServerCh::GetMaxEndpoints() const
    {
        return maxEndpoints;
    }

    /// Handles one received packet for this channel.
    ///
    /// Packets whose channel version differs from GetVersion() are ignored. Otherwise
    /// the packet is dispatched on its header:
    ///  - no token, internal connect op: create an endpoint, ask OnEndpointConnect(),
    ///    then either reject it or add it as ACTIVE/QUEUED and reply with connectedOp;
    ///  - token, internal disconnect op: acknowledge, notify, remove, update the queue;
    ///  - token, internal pong op: store the reported delta rate and send latency;
    ///  - token, internal received op: drop the matching insured message fragment;
    ///  - any other packet with a token: acknowledge if ensured and hand it to the endpoint;
    ///  - anything else is logged as corrupted.
    ///
    /// @param delta Not used by this function.
    /// @param endpoint Remote address/port the packet came from.
    /// @param header Parsed packet header.
    /// @param payload Packet payload, read from its current offset.
    void NetServerCh::Process(const float &delta, const Endpoint &endpoint, const Header &header, Serializer &payload)
    {
        if (header.channelVer != GetVersion())
            return;

        if (!header.ensure && !header.token[0] && header.systemId == internalSys && header.opId == connectOp)
        {
            NetEnd* end = new NetEnd(payload.ReadStr(), endpoint);
            end->owner = this;
            GenerateToken(end->token);
            end->SetStatus(NetStatus::PENDING);

            Serializer sPayload(Endianness::LE);

            // The callback only sees the bytes that follow the id string read above.
            if (!OnEndpointConnect(end, {Endianness::LE, &payload[payload.GetOffset()], payload.Size() - payload.GetOffset()}))
            {
                sPayload.WriteStr("Connection rejected.");

                // NOTE(review): `true` is passed in the encHashId position here (other
                // internal sends pass 0) — confirm this is intended.
                end->Send(false, true, false, internalSys, rejectedOp, sPayload);

                // The endpoint was never added to the list, so nothing else owns it;
                // free it here instead of leaking it on every rejected connection.
                delete end;

                return;
            }

            endpoints.Push(end);

            UInt_64 active = GetEndpointsCount(NetStatus::ACTIVE);

            if (maxEndpoints && active >= maxEndpoints)
            {
                end->SetStatus(NetStatus::QUEUED);

                // Assigns this endpoint its queue slot before it is reported back.
                UpdateQueue(active);

                sPayload.Write(NetStatus::QUEUED);
                sPayload.Write(end->GetQueueSlot());
            }
            else
            {
                end->SetStatus(NetStatus::ACTIVE);

                OnEndpointActive(end);

                sPayload.Write(NetStatus::ACTIVE);
                sPayload.Write(0);
            }

            sPayload.WriteSer(OnEndpointAccepted(end));

            end->Send(false, 0, false, internalSys, connectedOp, sPayload);
        }
        else if (!header.ensure && header.token[0] && header.systemId == internalSys && header.opId == disconnectOp)
        {
            NetEnd* end = GetEndpoint(header.token);
            if (!end)
                return;

            end->Send(false, 0, false, internalSys, disconnectedOp, {});

            OnEndpointDisconnect(end, {Endianness::LE, &payload[payload.GetOffset()], payload.Size() - payload.GetOffset()});

            RemoveEndpoint(end->token);

            // A slot may have been freed; promote queued endpoints.
            UpdateQueue();
        }
        else if (!header.ensure && header.token[0] && header.systemId == internalSys && header.opId == pongOp)
        {
            NetEnd* end = GetEndpoint(header.token);
            if (!end)
                return;

            end->SetDeltaRate(payload.Read());
            end->SendLatency();
        }
        else if (!header.ensure && header.token[0] && header.systemId == internalSys && header.opId == receivedOp)
        {
            NetEnd* end = GetEndpoint(header.token);
            if (!end)
                return;

            const UInt_64 msgId = payload.Read();
            const UInt_64 fragment = payload.Read();

            end->RemoveInsurance(msgId, fragment);
        }
        else if (header.token[0])
        {
            NetEnd* end = GetEndpoint(header.token);
            if (!end)
                return;

            if (IsDropPacketsEnabled() && !header.ensure && header.id < end->GetNextRecvId())
            {
                EHS_LOG_INT(LogType::INFO, 6, "Old packet intentionally dropped.");
                return;
            }

            if (header.ensure)
            {
                // Acknowledge the ensured packet by echoing its id and fragment index.
                Serializer sPayload(Endianness::LE);
                sPayload.Write(header.id);
                sPayload.Write(header.fragment);

                end->Send(false, 0, false, internalSys, receivedOp, sPayload);
            }

            end->AddReceived(
                header,
                Serializer<>(Endianness::LE, &payload[payload.GetOffset()], payload.Size() - payload.GetOffset())
            );
        }
        else
        {
            EHS_LOG_INT(LogType::INFO, 7, "Corrupted packet.");
        }
    }

    /// Fills `in` with a 64-byte token that no tracked endpoint currently uses.
    ///
    /// The first byte is guaranteed to be non-zero, because Process() treats
    /// `header.token[0] == 0` as "this packet carries no token".
    ///
    /// NOTE(review): tokens are what Process() uses to map a packet to an endpoint,
    /// and the generator is seeded from CPU::GetTSC(). If PRNG_u64 is not a
    /// cryptographic generator these tokens may be guessable — confirm.
    void NetServerCh::GenerateToken(Char_8 in[64])
    {
        PRNG_u64 rng(CPU::GetTSC());

        do
        {
            for (UInt_64 i = 0; i < 8; ++i)
            {
                // Only the first word is retried. Checking the first byte (rather than
                // the whole word) is what keeps the token visible to Process().
                do
                    ((UInt_64*)in)[i] = rng.Generate();
                while (!i && in[0] == 0);
            }
        }
        while (HasEndpoint(in)); // Regenerate on the (unlikely) collision with a live token.
    }

    /// Walks the endpoint list in order, promoting QUEUED endpoints to ACTIVE while
    /// `active` is below maxEndpoints and renumbering the queue slots of the rest.
    /// An endpoint is only notified when its status or slot actually changes.
    ///
    /// @param active Number of endpoints that are currently ACTIVE.
    void NetServerCh::UpdateQueue(UInt_64 active)
    {
        UInt_64 slot = 0;

        for (UInt_64 i = 0; i < endpoints.Size(); ++i)
        {
            if (endpoints[i]->GetStatus() == NetStatus::QUEUED)
            {
                if (active < maxEndpoints)
                {
                    endpoints[i]->SetStatus(NetStatus::ACTIVE);
                    endpoints[i]->SetQueueSlot(0);

                    Serializer payload(Endianness::LE);
                    payload.Write(NetStatus::ACTIVE);
                    payload.Write(0);

                    // NOTE(review): `true` is passed in the encHashId position — confirm.
                    endpoints[i]->Send(false, true, false, internalSys, statusUpdateOp, payload);

                    OnEndpointActive(endpoints[i]);

                    ++active;
                }
                else
                {
                    if (endpoints[i]->GetQueueSlot() != slot)
                    {
                        Serializer payload(Endianness::LE);
                        payload.Write(NetStatus::QUEUED);
                        payload.Write(slot);

                        // NOTE(review): `true` is passed in the encHashId position — confirm.
                        endpoints[i]->Send(false, true, false, internalSys, statusUpdateOp, payload);

                        endpoints[i]->SetQueueSlot(slot++);
                    }
                    else
                    {
                        ++slot;
                    }
                }
            }
        }
    }

    /// Recounts the ACTIVE endpoints and updates the queue accordingly.
    void NetServerCh::UpdateQueue()
    {
        UpdateQueue(GetEndpointsCount(NetStatus::ACTIVE));
    }

    /// Deletes and removes the first endpoint whose token matches.
    /// Removal swaps the entry with the last one, so list order is not preserved.
    /// @return true when an endpoint was removed.
    bool NetServerCh::RemoveEndpoint(const Char_8 token[64])
    {
        for (UInt_64 i = 0; i < endpoints.Size(); ++i)
        {
            if (Util::Compare(endpoints[i]->token, token, 64))
            {
                delete endpoints[i];

                if (i != endpoints.Size() - 1)
                    endpoints.Swap(i, endpoints.End());

                endpoints.Pop();

                return true;
            }
        }

        return false;
    }

    /// Deletes and removes the first endpoint with the same address and port.
    /// Removal swaps the entry with the last one, so list order is not preserved.
    /// @return true when an endpoint was removed.
    bool NetServerCh::RemoveEndpoint(const Endpoint &endpoint)
    {
        for (UInt_64 i = 0; i < endpoints.Size(); ++i)
        {
            if (endpoints[i]->GetEndpoint().address == endpoint.address && endpoints[i]->GetEndpoint().port == endpoint.port)
            {
                delete endpoints[i];

                if (i != endpoints.Size() - 1)
                    endpoints.Swap(i, endpoints.End());

                endpoints.Pop();

                return true;
            }
        }

        return false;
    }

    /// Deletes and removes the given endpoint if this channel tracks it.
    /// Removal swaps the entry with the last one, so list order is not preserved.
    /// @return true when the endpoint was found and removed.
    bool NetServerCh::RemoveEndpoint(const NetEnd* const end)
    {
        for (UInt_64 i = 0; i < endpoints.Size(); ++i)
        {
            if (endpoints[i] == end)
            {
                delete endpoints[i];

                if (i != endpoints.Size() - 1)
                    endpoints.Swap(i, endpoints.End());

                endpoints.Pop();

                return true;
            }
        }

        return false;
    }

    /// Polls every endpoint, drops the ones that reached GetMaxTimeout(), and for
    /// non-PENDING endpoints executes every completed received message on the
    /// system named in its header.
    ///
    /// @param delta Forwarded to each endpoint's Poll().
    void NetServerCh::Poll(const float &delta)
    {
        UInt_64 i = 0;
        while (i < endpoints.Size())
        {
            endpoints[i]->Poll(delta);

            if (endpoints[i]->GetStatus() == NetStatus::PENDING)
            {
                if (endpoints[i]->GetTimeout() >= GetMaxTimeout())
                {
                    OnEndpointTimeout(endpoints[i]);

                    delete endpoints[i];

                    if (i != endpoints.Size() - 1)
                        endpoints.Swap(i, endpoints.End());

                    endpoints.Pop();

                    // Index i now holds the swapped-in endpoint; revisit it.
                    continue;
                }
            }
            else
            {
                if (endpoints[i]->GetTimeout() >= GetMaxTimeout())
                {
                    OnEndpointTimeout(endpoints[i]);

                    delete endpoints[i];

                    if (i != endpoints.Size() - 1)
                        endpoints.Swap(i, endpoints.End());

                    endpoints.Pop();

                    // A non-pending endpoint left; queued endpoints may move up.
                    UpdateQueue();

                    // Index i now holds the swapped-in endpoint; revisit it.
                    continue;
                }

                Vector* frags = endpoints[i]->GetReceived();

                UInt_64 f = 0;
                while (f < frags->Size())
                {
                    if (!(*frags)[f].IsComplete())
                    {
                        ++f;
                        continue;
                    }

                    Packet packet = (*frags)[f].Combine();

                    // Messages for an unknown system are skipped and stay in the list.
                    NetSys* sys = GetSystem(packet.header.systemId);
                    if (!sys)
                    {
                        ++f;
                        continue;
                    }

                    sys->Execute(this, endpoints[i], packet.header.opId, packet.payload);

                    // Swap-and-pop removal; f is not advanced so the swapped-in entry is checked next.
                    frags->Swap(f, frags->End());
                    frags->Pop();
                }
            }

            ++i;
        }
    }

    /// Sends a disconnect (carrying OnShutdown()'s payload) to every endpoint, frees
    /// them all and empties the endpoint list.
    void NetServerCh::Shutdown()
    {
        Serializer payload = OnShutdown();

        for (UInt_64 i = 0; i < endpoints.Size(); i++)
        {
            endpoints[i]->Send(false, 0, false, internalSys, disconnectOp, payload);
            delete endpoints[i];
        }

        // Drop the now-dangling pointers so a later Shutdown() (for example from the
        // destructor) or any lookup cannot touch or delete freed endpoints again.
        endpoints = {};
    }
}