#include "ehs/io/socket/SSL.h"

// NOTE(review): the header names of the six system includes were missing from the
// reviewed text (bare `#include` directives). They are restored here from the OpenSSL
// symbols this file uses — confirm against the original include list.
#include <openssl/ssl.h>
#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/pem.h>
#include <openssl/provider.h>
#include <openssl/x509.h>

namespace
{
    /// Largest byte count that fits the `int` length parameter of SSL_read/SSL_write.
    constexpr unsigned int maxIoChunk = 0x7FFFFFFF;

    /// Performs the process-wide OpenSSL setup exactly once.
    ///
    /// OPENSSL_init_ssl supersedes the legacy SSL_library_init /
    /// OpenSSL_add_ssl_algorithms / SSL_load_error_strings calls, and loading the
    /// default provider once avoids acquiring a new provider reference per socket.
    void InitOpenSslOnce()
    {
        static const bool initialized = []
        {
            OPENSSL_init_ssl(OPENSSL_INIT_LOAD_CONFIG, nullptr);
            OSSL_PROVIDER_load(nullptr, "default");
            return true;
        }();

        static_cast<void>(initialized);
    }

    /// Frees every OpenSSL object owned by a socket and resets the pointers to null.
    ///
    /// @param hdl             TLS session handle; may be null.
    /// @param ctx             TLS context; may be null.
    /// @param cert            Certificate handle; may be null.
    /// @param pkey            Private key handle; may be null.
    /// @param sendCloseNotify When true, SSL_shutdown is called on a non-null handle
    ///                        before it is freed.
    void FreeTlsState(::SSL*& hdl, SSL_CTX*& ctx, X509*& cert, EVP_PKEY*& pkey, const bool sendCloseNotify)
    {
        // The OpenSSL *_free functions are no-ops on null pointers.
        EVP_PKEY_free(pkey);
        pkey = nullptr;

        X509_free(cert);
        cert = nullptr;

        if (hdl)
        {
            if (sendCloseNotify)
                SSL_shutdown(hdl);

            SSL_free(hdl);
            hdl = nullptr;
        }

        SSL_CTX_free(ctx);
        ctx = nullptr;
    }
}

namespace ehs
{
    /// Releases all OpenSSL objects still owned by this socket.
    ///
    /// Cleanup is unconditional (not gated on IsValid()) so a context created by
    /// Initialize() is freed even when no session handle was ever created.
    SSL::~SSL()
    {
        FreeTlsState(sslHdl, ctx, cert, pkey, static_cast<bool>(connection));
    }

    /// Creates an uninitialized client-mode socket with no OpenSSL state.
    SSL::SSL()
        : server(false), ctx(nullptr), sslHdl(nullptr), cert(nullptr), pkey(nullptr)
    {
    }

    /// Creates a socket of the given IP type and immediately initializes it.
    ///
    /// @param type   IP type forwarded to the TCP base.
    /// @param server True to create a server-side TLS context, false for client-side.
    SSL::SSL(const IP &type, const bool &server)
        : TCP(type), server(server), ctx(nullptr), sslHdl(nullptr), cert(nullptr), pkey(nullptr)
    {
        SSL::Initialize();
    }

    /// Takes over an existing TCP socket; no OpenSSL state is created.
    SSL::SSL(TCP&& tcp) noexcept
        : TCP(std::move(tcp)), server(false), ctx(nullptr), sslHdl(nullptr), cert(nullptr), pkey(nullptr)
    {
    }

    /// Copies an existing TCP socket; no OpenSSL state is created.
    SSL::SSL(const TCP& tcp)
        : TCP(tcp), server(false), ctx(nullptr), sslHdl(nullptr), cert(nullptr), pkey(nullptr)
    {
    }

    /// Copies the TCP part and the server flag; OpenSSL objects are never shared
    /// between copies, so the new object starts with none.
    SSL::SSL(const SSL& ssl)
        : TCP(ssl), server(ssl.server), ctx(nullptr), sslHdl(nullptr), cert(nullptr), pkey(nullptr)
    {
    }

    /// Copy-assigns the TCP part and the server flag.
    ///
    /// Any OpenSSL objects this socket owned are freed first (previously they were
    /// overwritten with null and leaked); afterwards this socket owns none.
    SSL& SSL::operator=(const SSL& ssl)
    {
        if (this == &ssl)
            return *this;

        // Tear down the TLS state while the current socket state is still in place.
        FreeTlsState(sslHdl, ctx, cert, pkey, static_cast<bool>(connection));

        TCP::operator=(ssl);

        server = ssl.server;

        return *this;
    }

    /// Initializes the TCP base and creates the TLS context for this socket.
    ///
    /// Does nothing beyond TCP::Initialize() when a context already exists or the
    /// socket is already valid. On context-creation failure the OpenSSL error is
    /// logged and the socket is left without a context.
    void SSL::Initialize()
    {
        TCP::Initialize();

        if (ctx || IsValid())
            return;

        InitOpenSslOnce();

        ctx = SSL_CTX_new(server ? TLS_server_method() : TLS_client_method());
        if (!ctx)
        {
            UInt_32 code = ERR_get_error();
            EHS_LOG_INT(LogType::ERR, 0, ERR_error_string(code, nullptr));
            return;
        }

        // Must run after the context exists; calling it on a null context has no effect.
        SSL_CTX_set_min_proto_version(ctx, TLS1_2_VERSION);

        if (server)
        {
            // TLS 1.2 and below cipher selection.
            if (SSL_CTX_set_cipher_list(ctx, "HIGH:!aNULL:!MD5") != 1)
            {
                UInt_32 code = ERR_get_error();
                EHS_LOG_INT(LogType::ERR, 1, ERR_error_string(code, nullptr));
            }

            // TLS 1.3 suites only. The TLS 1.2 name ECDHE-RSA-AES128-GCM-SHA256 that
            // used to be listed here is not a TLS 1.3 ciphersuite; it matches the
            // HIGH class configured above.
            if (SSL_CTX_set_ciphersuites(ctx,
                "TLS_AES_256_GCM_SHA384:TLS_AES_128_GCM_SHA256:"
                "TLS_CHACHA20_POLY1305_SHA256") != 1)
            {
                UInt_32 code = ERR_get_error();
                EHS_LOG_INT(LogType::ERR, 2, ERR_error_string(code, nullptr));
            }
        }
        else
        {
            SSL_CTX_set_default_verify_paths(ctx);
            SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, nullptr);
        }
    }

    /// Frees all OpenSSL objects and then releases the TCP base.
    ///
    /// The TLS teardown runs before TCP::Release() so that a close-notify can still
    /// be sent, and it is unconditional so nothing leaks once the TCP socket is gone.
    void SSL::Release()
    {
        FreeTlsState(sslHdl, ctx, cert, pkey, static_cast<bool>(connection));

        TCP::Release();
    }

    /// Creates the session handle for the listening socket and starts listening.
    ///
    /// If the handle cannot be created the OpenSSL error is logged and the socket
    /// does not start listening.
    void SSL::Listen()
    {
        // Replace any handle left over from an earlier call instead of leaking it.
        SSL_free(sslHdl);

        sslHdl = SSL_new(ctx);
        if (!sslHdl)
        {
            UInt_32 code = ERR_get_error();
            EHS_LOG_INT(LogType::ERR, 0, ERR_error_string(code, nullptr));
            return;
        }

        SSL_set_fd(sslHdl, hdl);

        TCP::Listen();

        EHS_LOG_SUCCESS();
    }

    /// Accepts a pending TCP connection and runs the server-side TLS handshake on it.
    ///
    /// @return A heap-allocated socket owned by the caller, or nullptr when this
    ///         socket is not bound, no connection was accepted, or the handshake
    ///         failed. A handshake that merely needs more I/O (WANT_READ/WANT_WRITE)
    ///         is not treated as a failure.
    SSL* SSL::Accept()
    {
        if (!bound)
            return nullptr;

        TCP* tcp = TCP::Accept();
        if (!tcp)
            return nullptr;

        SSL* client = new SSL(std::move(*tcp));
        delete tcp;

        client->sslHdl = SSL_new(this->ctx);
        if (!client->sslHdl)
        {
            UInt_32 code = ERR_get_error();
            EHS_LOG_INT(LogType::ERR, 0, ERR_error_string(code, nullptr));

            delete client;
            return nullptr;
        }

        SSL_set_fd(client->sslHdl, client->hdl);

        // SSL_accept reports success with 1; 0 and negative values are failures.
        const int rc = SSL_accept(client->sslHdl);
        if (rc <= 0)
        {
            const int code = SSL_get_error(client->sslHdl, rc);
            if (code != SSL_ERROR_WANT_READ && code != SSL_ERROR_WANT_WRITE)
            {
                EHS_LOG_INT(LogType::ERR, 0, "Failed SSL handshake with error #" + Str_8::FromNum(code) + ".");

                // Free the handle directly so the destructor does not attempt a
                // shutdown on a session whose handshake failed.
                SSL_free(client->sslHdl);
                client->sslHdl = nullptr;

                delete client;
                return nullptr;
            }
        }

        EHS_LOG_SUCCESS();

        return client;
    }

    /// Connects the TCP socket and runs the client-side TLS handshake.
    ///
    /// @param address Remote address; also sent as the SNI host name.
    /// @param port    Remote port.
    ///
    /// NOTE(review): the peer certificate is verified against the trust store
    /// (SSL_VERIFY_PEER is set in Initialize), but no expected host name is
    /// registered for verification (e.g. SSL_set1_host) — confirm whether host name
    /// checking should be enabled.
    void SSL::Connect(Str_8 address, const UInt_16 &port)
    {
        TCP::Connect(address, port);

        // Replace any handle left over from an earlier call instead of leaking it.
        SSL_free(sslHdl);

        sslHdl = SSL_new(ctx);
        if (!sslHdl)
        {
            UInt_32 code = ERR_get_error();
            EHS_LOG_INT(LogType::ERR, 0, ERR_error_string(code, nullptr));
            return;
        }

        SSL_set_fd(sslHdl, hdl);
        SSL_set_tlsext_host_name(sslHdl, &address[0]);

        SInt_32 rc = SSL_connect(sslHdl);
        if (rc != 1)
        {
            EHS_LOG_INT(LogType::ERR, 1, "Failed to connect with error #" + Str_8::FromNum(SSL_get_error(sslHdl, rc)) + ".");
            return;
        }

        EHS_LOG_SUCCESS();
    }

    /// Writes up to `size` bytes from `buffer` to the TLS session.
    ///
    /// @return Number of bytes written, or 0 when nothing was written (including
    ///         SSL_ERROR_WANT_WRITE, which is not logged as an error).
    UInt_64 SSL::Send(const Byte* const buffer, const UInt_32 size)
    {
        if (!sslHdl)
        {
            EHS_LOG_INT(LogType::ERR, 1, "Attempted to send data without an SSL handle.");
            return 0;
        }

        if (!size)
            return 0;

        // SSL_write takes an int length; clamp so large sizes do not turn negative.
        const int chunk = size > maxIoChunk ? static_cast<int>(maxIoChunk) : static_cast<int>(size);

        const int written = SSL_write(sslHdl, buffer, chunk);
        if (written <= 0)
        {
            const int code = SSL_get_error(sslHdl, written);
            if (code != SSL_ERROR_WANT_WRITE)
            {
                ERR_print_errors_fp(stderr);
                EHS_LOG_INT(LogType::ERR, 0, "Failed to send data with error #" + Str_8::FromNum(code) + ".");
            }

            return 0;
        }

        return static_cast<UInt_64>(written);
    }

    /// Reads up to `size` bytes from the TLS session into `buffer`.
    ///
    /// @return Number of bytes read, or 0 when nothing was read (including
    ///         SSL_ERROR_WANT_READ, which is not logged as an error).
    UInt_64 SSL::Receive(Byte* const buffer, const UInt_32 size)
    {
        if (!sslHdl)
        {
            EHS_LOG_INT(LogType::ERR, 1, "Attempted to receive data without an SSL handle.");
            return 0;
        }

        if (!size)
            return 0;

        // SSL_read takes an int length; clamp so large sizes do not turn negative.
        const int chunk = size > maxIoChunk ? static_cast<int>(maxIoChunk) : static_cast<int>(size);

        const int received = SSL_read(sslHdl, buffer, chunk);
        if (received <= 0)
        {
            const int code = SSL_get_error(sslHdl, received);
            if (code != SSL_ERROR_WANT_READ)
            {
                ERR_print_errors_fp(stderr);
                EHS_LOG_INT(LogType::ERR, 0, "Failed to receive data with error #" + Str_8::FromNum(code) + ".");
            }

            return 0;
        }

        return static_cast<UInt_64>(received);
    }

    /// Parses a PEM certificate from memory and installs it on the TLS context.
    ///
    /// @param data PEM-encoded certificate bytes.
    /// @param size Number of bytes in `data`.
    ///
    /// Failures (no context, unreadable PEM, rejected certificate) are logged with
    /// the OpenSSL error string where one is available.
    void SSL::UseCertificate(const Char_8* data, const UInt_32 &size)
    {
        if (!ctx)
        {
            EHS_LOG_INT(LogType::ERR, 1, "Attempted to use a certificate without an SSL context.");
            return;
        }

        BIO* certBio = BIO_new_mem_buf(data, static_cast<int>(size));
        if (!certBio)
        {
            UInt_32 code = ERR_get_error();
            EHS_LOG_INT(LogType::ERR, 0, ERR_error_string(code, nullptr));
            return;
        }

        // Named so it does not shadow the `cert` member.
        X509* parsedCert = PEM_read_bio_X509(certBio, nullptr, nullptr, nullptr);

        // The BIO is only needed for parsing; free it on every path.
        BIO_free(certBio);

        if (!parsedCert)
        {
            UInt_32 code = ERR_get_error();
            EHS_LOG_INT(LogType::ERR, 0, ERR_error_string(code, nullptr));
            return;
        }

        if (!SSL_CTX_use_certificate(ctx, parsedCert))
        {
            UInt_32 code = ERR_get_error();
            EHS_LOG_INT(LogType::ERR, 0, ERR_error_string(code, nullptr));
        }

        // The context keeps its own reference; drop the local one.
        X509_free(parsedCert);
    }

    /// Parses a PEM private key from memory and installs it on the TLS context.
    ///
    /// @param data PEM-encoded private key bytes.
    /// @param size Number of bytes in `data`.
    ///
    /// Failures (no context, unreadable PEM, rejected key) are logged with the
    /// OpenSSL error string where one is available.
    void SSL::UsePrivateKey(const Char_8* data, const UInt_32 &size)
    {
        if (!ctx)
        {
            EHS_LOG_INT(LogType::ERR, 1, "Attempted to use a private key without an SSL context.");
            return;
        }

        BIO* keyBio = BIO_new_mem_buf(data, static_cast<int>(size));
        if (!keyBio)
        {
            UInt_32 code = ERR_get_error();
            EHS_LOG_INT(LogType::ERR, 0, ERR_error_string(code, nullptr));
            return;
        }

        // Named so it does not shadow the `pkey` member.
        EVP_PKEY* parsedKey = PEM_read_bio_PrivateKey(keyBio, nullptr, nullptr, nullptr);

        // The BIO is only needed for parsing; free it on every path.
        BIO_free(keyBio);

        if (!parsedKey)
        {
            UInt_32 code = ERR_get_error();
            EHS_LOG_INT(LogType::ERR, 0, ERR_error_string(code, nullptr));
            return;
        }

        if (!SSL_CTX_use_PrivateKey(ctx, parsedKey))
        {
            UInt_32 code = ERR_get_error();
            EHS_LOG_INT(LogType::ERR, 0, ERR_error_string(code, nullptr));
        }

        // The context keeps its own reference; drop the local one.
        EVP_PKEY_free(parsedKey);
    }

    /// @return True when the TCP base is valid and a TLS session handle exists.
    bool SSL::IsValid()
    {
        return TCP::IsValid() && sslHdl;
    }
}