#include "ehs/io/socket/BaseTCP.h"
#include "ehs/Log.h"

namespace ehs
{
    /// Default constructor: IPv6 address type, all ports zeroed and all state flags cleared.
    BaseTCP::BaseTCP()
        : addrType(AddrType::IPV6), localPort(0), remotePort(0), connection(false), bound(false),
        listening(false), connected(false)
    {
    }

    /// Constructs an idle socket description for the given address type.
    /// @param [in] addrType The address family stored in this object.
    BaseTCP::BaseTCP(const AddrType addrType)
        : addrType(addrType), localPort(0), remotePort(0), connection(false), bound(false),
        listening(false), connected(false)
    {
    }

    /// Move constructor: takes over the addresses, ports and state flags of the given object.
    /// The moved-from object is reset to the same idle state that move assignment leaves behind,
    /// so it no longer reports itself as bound, listening or connected.
    /// @param [in] tcp The object to move from.
    BaseTCP::BaseTCP(BaseTCP&& tcp) noexcept
        : addrType(tcp.addrType), localAddr(std::move(tcp.localAddr)), localPort(tcp.localPort),
        remoteHostName(std::move(tcp.remoteHostName)), remoteAddr(std::move(tcp.remoteAddr)),
        remotePort(tcp.remotePort), connection(tcp.connection), bound(tcp.bound),
        listening(tcp.listening), connected(tcp.connected)
    {
        tcp.addrType = AddrType::IPV6;
        tcp.localPort = 0;
        tcp.remotePort = 0;
        tcp.connection = false;
        tcp.bound = false;
        tcp.listening = false;
        tcp.connected = false;
    }

    /// Copy constructor: only the address type is copied. Addresses, ports and state flags start
    /// out cleared.
    /// @param [in] tcp The object whose address type is copied.
    BaseTCP::BaseTCP(const BaseTCP& tcp)
        : addrType(tcp.addrType), localPort(0), remotePort(0), connection(false), bound(false),
        listening(false), connected(false)
    {
    }

    /// Move assignment: takes over the addresses, ports and state flags of the given object and
    /// resets the moved-from object to an idle state.
    /// @param [in] tcp The object to move from.
    /// @returns A reference to this object.
    BaseTCP& BaseTCP::operator=(BaseTCP&& tcp) noexcept
    {
        if (this == &tcp)
            return *this;

        addrType = tcp.addrType;
        localAddr = std::move(tcp.localAddr);
        localPort = tcp.localPort;
        remoteHostName = std::move(tcp.remoteHostName);
        remoteAddr = std::move(tcp.remoteAddr);
        remotePort = tcp.remotePort;
        connection = tcp.connection;
        bound = tcp.bound;
        listening = tcp.listening;
        connected = tcp.connected;

        tcp.addrType = AddrType::IPV6;
        tcp.localPort = 0;
        tcp.remotePort = 0;
        tcp.connection = false;
        tcp.bound = false;
        tcp.listening = false;
        tcp.connected = false;

        return *this;
    }

    /// Copy assignment: only the address type is copied. Addresses, ports and state flags of this
    /// object are cleared.
    /// @param [in] tcp The object whose address type is copied.
    /// @returns A reference to this object.
    BaseTCP& BaseTCP::operator=(const BaseTCP& tcp)
    {
        if (this == &tcp)
            return *this;

        addrType = tcp.addrType;
        localAddr = Str_8();
        localPort = 0;
        remoteHostName = Str_8();
        remoteAddr = Str_8();
        remotePort = 0;
        connection = false;
        bound = false;
        listening = false;
        connected = false;

        return *this;
    }

    /// Sends the whole string, calling Send repeatedly until every byte has been handed over.
    /// Does nothing when IsValid() is false. Logs an error and stops as soon as Send reports
    /// zero bytes.
    /// @param [in] str The text to send.
    void BaseTCP::SendStr(const Str_8& str)
    {
        if (!IsValid())
            return;

        UInt_64 offset = 0;
        UInt_64 size = str.Size(true);

        while (offset < size)
        {
            UInt_64 sent = Send((Byte*)&str[offset], size - offset);
            if (!sent)
            {
                EHS_LOG_INT(LogType::ERR, 0, "Failed to send data.");
                return;
            }

            offset += sent;
        }
    }

    /// Sends the formed result of the given response. Does nothing when IsValid() is false.
    /// @param [in] res The response to send.
    void BaseTCP::SendRes(const Response& res)
    {
        if (!IsValid())
            return;

        SendStr(res.FormResult());
    }

    /// Adds a "Host" header holding the remote host name to the request, then sends its formed
    /// result. Does nothing when IsValid() is false.
    /// @param [in,out] req The request to send; it is modified by the added header.
    void BaseTCP::SendReq(Request& req)
    {
        if (!IsValid())
            return;

        req.AddToHeader("Host", remoteHostName);

        SendStr(req.FormResult());
    }

    /// Receives a response: the header first, then the body according to either the content
    /// length header or the "chunked" transfer encoding.
    /// @returns The received response, or a default constructed one when IsValid() is false or
    /// no header could be received.
    Response BaseTCP::RecvRes()
    {
        if (!IsValid())
            return {};

        Str_8 header = RecvHeader();
        if (!header.Size())
            return {};

        Response response(header);
        Str_8 encoding = response.GetHeader("Transfer-Encoding");

        if (!encoding.Size())
        {
            // NOTE(review): RecvReq looks this header up as "Content-Length"; confirm whether
            // GetHeader is case-insensitive before unifying the two spellings.
            int bodySize = response.GetHeader("content-length").ToDecimal();

            // A negative length would turn into a huge unsigned size inside RecvBody.
            if (bodySize <= 0)
                return response;

            response.SetBody(RecvBody(bodySize));
        }
        else if (encoding == "chunked")
        {
            Str_8 body;

            // A chunk size of zero marks the end of the body (or a receive failure).
            for (UInt_64 chunkSize = RecvChunkSize(); chunkSize; chunkSize = RecvChunkSize())
                body += RecvChunk(chunkSize);

            response.SetBody(body);
        }

        return response;
    }

    /// Receives a request: the header first, then (for anything but GET) the body according to
    /// either the content length header or the "chunked" transfer encoding.
    /// @returns The received request, or a default constructed one when IsValid() is false or
    /// no header could be received.
    Request BaseTCP::RecvReq()
    {
        if (!IsValid())
            return {};

        Str_8 header = RecvHeader();
        if (!header.Size())
            return {};

        Request request(header);

        // GET requests are returned without attempting to read a body.
        if (request.GetVerb() == Verb::GET)
            return request;

        Str_8 encoding = request.GetHeader("Transfer-Encoding");

        if (!encoding.Size())
        {
            int bodySize = request.GetHeader("Content-Length").ToDecimal();

            // A negative length would turn into a huge unsigned size inside RecvBody.
            if (bodySize <= 0)
                return request;

            request.SetBody(RecvBody(bodySize));
        }
        else if (encoding == "chunked")
        {
            Str_8 body;

            // A chunk size of zero marks the end of the body (or a receive failure).
            for (UInt_64 chunkSize = RecvChunkSize(); chunkSize; chunkSize = RecvChunkSize())
                body += RecvChunk(chunkSize);

            request.SetBody(body);
        }

        return request;
    }

    /// @returns The stored address type.
    AddrType BaseTCP::GetAddressType() const
    {
        return addrType;
    }

    /// @returns The stored local address.
    Str_8 BaseTCP::GetLocalAddress() const
    {
        return localAddr;
    }

    /// @returns The stored local port.
    unsigned short BaseTCP::GetLocalPort() const
    {
        return localPort;
    }

    /// @returns The stored remote address.
    Str_8 BaseTCP::GetRemoteAddress() const
    {
        return remoteAddr;
    }

    /// @returns The stored remote port.
    unsigned short BaseTCP::GetRemotePort() const
    {
        return remotePort;
    }

    /// @returns The value of the connection flag.
    bool BaseTCP::IsConnection() const
    {
        return connection;
    }

    /// @returns The value of the bound flag.
    bool BaseTCP::IsBound() const
    {
        return bound;
    }

    /// @returns The value of the listening flag.
    bool BaseTCP::IsListening() const
    {
        return listening;
    }

    /// @returns The value of the connected flag.
    bool BaseTCP::IsConnected() const
    {
        return connected;
    }

    /// Receives one byte at a time until the "\r\n\r\n" terminator has been seen.
    /// @returns The received text without the terminator, or an empty string when receiving
    /// fails or the terminator does not show up within MaxHeaderSize bytes.
    Str_8 BaseTCP::RecvHeader()
    {
        Byte buffer[MaxHeaderSize];
        UInt_64 offset = 0;

        // Bounded by the buffer size so a peer that never terminates the header cannot make
        // this write past the end of the stack buffer.
        while (offset < MaxHeaderSize)
        {
            if (!Receive(&buffer[offset], 1))
                return {};

            // The terminator is four bytes long, so it can only end at index three or later.
            // Checking this first also keeps the look-behind indices from wrapping around.
            if (offset >= 3 && buffer[offset] == '\n' && buffer[offset - 1] == '\r' &&
                buffer[offset - 2] == '\n' && buffer[offset - 3] == '\r')
                return {(Char_8*)buffer, (UInt_64)(offset - 3)};

            ++offset;
        }

        EHS_LOG_INT(LogType::ERR, 0, "Header exceeds the maximum header size.");

        return {};
    }

    /// Receives exactly the given number of bytes, calling Receive as often as needed.
    /// @param [in] contentLength The number of bytes to receive.
    /// @returns The received data, or an empty string when Receive reports zero bytes.
    Str_8 BaseTCP::RecvBody(const UInt_64 contentLength)
    {
        Str_8 buffer(contentLength);
        UInt_64 offset = 0;

        while (offset < contentLength)
        {
            UInt_64 received = Receive((Byte*)&buffer[offset], contentLength - offset);
            if (!received)
            {
                EHS_LOG_INT(LogType::ERR, 0, "Failed to receive data.");
                return {};
            }

            offset += received;
        }

        return buffer;
    }

    /// Receives the size line of a chunk (hexadecimal digits followed by "\r\n") and converts
    /// it to a number. When the parsed size is zero, the two bytes that follow the size line
    /// are consumed as well.
    /// @returns The chunk size, or zero when receiving fails or the size line does not fit the
    /// line buffer.
    UInt_64 BaseTCP::RecvChunkSize()
    {
        const UInt_64 lineCapacity = 10;

        Str_8 hexSize(lineCapacity);
        UInt_64 offset = 0;
        bool cr = false;

        while (true)
        {
            // Refuse to write past the end of the line buffer.
            if (offset >= lineCapacity)
            {
                EHS_LOG_INT(LogType::ERR, 1, "Chunk size line is too long.");
                return 0;
            }

            UInt_64 received = Receive((Byte*)&hexSize[offset], 1);
            if (!received)
            {
                EHS_LOG_INT(LogType::ERR, 0, "Failed to receive data.");
                return 0;
            }
            else if (hexSize[offset] == '\r')
                cr = true;
            else if (cr && hexSize[offset] == '\n')
                break;

            ++offset;
        }

        // offset is the index of the '\n', so offset - 1 is the number of characters before
        // the '\r'.
        hexSize.Resize(offset - 1);

        const UInt_64 chunkSize = hexSize.HexToNum();

        // Decide on the parsed value rather than on the first character: a size written with a
        // leading zero (for example "0A") is not the last chunk and its data must stay unread.
        if (offset > 1 && !chunkSize)
        {
            Byte trailer[2];
            UInt_64 consumed = 0;

            while (consumed < 2)
            {
                UInt_64 received = Receive(&trailer[consumed], 2 - consumed);
                if (!received)
                    break;

                consumed += received;
            }
        }

        return chunkSize;
    }

    /// Receives the data of one chunk together with the two bytes that follow it.
    /// @param [in] chunkSize The number of data bytes in the chunk.
    /// @returns The chunk data without the two trailing bytes, or an empty string when Receive
    /// reports zero bytes.
    Str_8 BaseTCP::RecvChunk(const UInt_64 chunkSize)
    {
        // Two extra bytes for the line break that follows the chunk data.
        const UInt_64 total = chunkSize + 2;

        Str_8 buffer(total);
        UInt_64 offset = 0;

        while (offset < total)
        {
            UInt_64 received = Receive((Byte*)&buffer[offset], total - offset);
            if (!received)
            {
                EHS_LOG_INT(LogType::ERR, 0, "Failed to receive data.");
                return {};
            }

            offset += received;
        }

        buffer.Resize(offset - 2);

        return buffer;
    }
}