diff --git a/include/ehs/io/socket/BaseTCP.h b/include/ehs/io/socket/BaseTCP.h index b5e521f..4010bcb 100644 --- a/include/ehs/io/socket/BaseTCP.h +++ b/include/ehs/io/socket/BaseTCP.h @@ -150,6 +150,8 @@ namespace ehs virtual bool IsIPv6Only() const = 0; + virtual void SetReuse(const bool &enabled) = 0; + /// Retrieves whether or not this socket was initialized. /// @returns The result. virtual bool IsValid() const = 0; diff --git a/include/ehs/io/socket/SSL.h b/include/ehs/io/socket/SSL.h index ff6b528..43074e4 100644 --- a/include/ehs/io/socket/SSL.h +++ b/include/ehs/io/socket/SSL.h @@ -23,7 +23,7 @@ namespace ehs SSL(); - SSL(const IP type); + SSL(const IP &type); SSL(TCP&& tcp) noexcept; @@ -39,6 +39,8 @@ namespace ehs void Bind(const Str_8& address, unsigned short port) override; + void Listen() override; + SSL* Accept() override; void Connect(const Str_8& address, const UInt_16 port) override; diff --git a/include/ehs/io/socket/TCP_BSD.h b/include/ehs/io/socket/TCP_BSD.h index 09ae2eb..605259d 100644 --- a/include/ehs/io/socket/TCP_BSD.h +++ b/include/ehs/io/socket/TCP_BSD.h @@ -84,6 +84,8 @@ namespace ehs bool IsIPv6Only() const override; + void SetReuse(const bool &value) override; + bool IsValid() const override; private: diff --git a/src/io/socket/SSL.cpp b/src/io/socket/SSL.cpp index a2fd4e3..6c4134f 100644 --- a/src/io/socket/SSL.cpp +++ b/src/io/socket/SSL.cpp @@ -29,7 +29,7 @@ namespace ehs { } - SSL::SSL(const IP type) + SSL::SSL(const IP &type) : TCP(type), ctx(nullptr), sslHdl(nullptr) { SSL::Initialize(); @@ -102,9 +102,14 @@ namespace ehs if (bound) return; - OpenSSL_add_ssl_algorithms(); - SSL_load_error_strings(); - ctx = SSL_CTX_new(TLS_server_method()); + TCP::Bind(address, port); + } + + void SSL::Listen() + { + OpenSSL_add_ssl_algorithms(); + SSL_load_error_strings(); + ctx = SSL_CTX_new(TLS_server_method()); SSL_CTX_set_min_proto_version(ctx, TLS1_2_VERSION); @@ -118,18 +123,20 @@ namespace ehs SSL_CTX_set_ecdh_auto(ctx, 1); #endif - sslHdl = SSL_new(ctx); - SSL_set_fd(sslHdl, hdl); + sslHdl = SSL_new(ctx); + SSL_set_fd(sslHdl, hdl); - TCP::Bind(address, port); + TCP::Listen(); } - SSL* SSL::Accept() + SSL* SSL::Accept() { if (!bound) return nullptr; TCP* tcp = TCP::Accept(); + if (!tcp) + return nullptr; SSL* client = new SSL(std::move(*tcp)); @@ -151,9 +158,6 @@ namespace ehs void SSL::Connect(const Str_8& address, const UInt_16 port) { - if (bound) - return; - TCP::Connect(address, port); OpenSSL_add_ssl_algorithms(); @@ -192,8 +196,12 @@ namespace ehs if (written <= 0) { int code = SSL_get_error(sslHdl, written); - ERR_print_errors_fp(stderr); - EHS_LOG_INT(LogType::ERR, 0, "Failed to send data with error #" + Str_8::FromNum(code) + "."); + if (code != SSL_ERROR_WANT_WRITE) + { + ERR_print_errors_fp(stderr); + EHS_LOG_INT(LogType::ERR, 0, "Failed to send data with error #" + Str_8::FromNum(code) + "."); + } + return 0; } @@ -206,8 +214,12 @@ namespace ehs if (received <= 0) { int code = SSL_get_error(sslHdl, received); - ERR_print_errors_fp(stderr); - EHS_LOG_INT(LogType::ERR, 0, "Failed to receive data with error #" + Str_8::FromNum(code) + "."); + if (code != SSL_ERROR_WANT_READ) + { + ERR_print_errors_fp(stderr); + EHS_LOG_INT(LogType::ERR, 0, "Failed to receive data with error #" + Str_8::FromNum(code) + "."); + } + return 0; } diff --git a/src/io/socket/TCP_BSD.cpp b/src/io/socket/TCP_BSD.cpp index 4bfed25..2c43c15 100644 --- a/src/io/socket/TCP_BSD.cpp +++ b/src/io/socket/TCP_BSD.cpp @@ -385,6 +385,24 @@ namespace ehs return result; } + void TCP::SetReuse(const bool &value) + { + if (!IsValid()) + { + EHS_LOG_INT(LogType::WARN, 1, "Attempted to set address and port reuse while socket is not initialized."); + return; + } + + const int result = (int)value; + if (setsockopt(hdl, SOL_SOCKET, SO_REUSEPORT, &result, sizeof(int)) == -1) + { + EHS_LOG_INT(LogType::ERR, 2, "Failed to set address and port reuse with error #" + Str_8::FromNum(errno) + "."); + return; + } + + EHS_LOG_SUCCESS(); + } + bool TCP::IsValid() const { return hdl != EHS_INVALID_SOCKET;