#include "ehs/io/socket/SSL.h"

// OpenSSL headers for the SSL_*, ERR_*, X509 and EVP_PKEY APIs used below.
#include <openssl/ssl.h>
#include <openssl/err.h>
#include <openssl/x509.h>
#include <openssl/evp.h>

#include <cstdio>
#include <limits>
#include <utility>

namespace
{
	/// Frees an OpenSSL connection handle and/or context and resets both pointers to null.
	/// @param hdl OpenSSL connection handle; may be null.
	/// @param ctx OpenSSL context; may be null.
	/// @param shutdown When true, SSL_shutdown is called on hdl before it is freed.
	void FreeSslState(::SSL*& hdl, SSL_CTX*& ctx, const bool shutdown)
	{
		if (hdl)
		{
			if (shutdown)
				SSL_shutdown(hdl);

			SSL_free(hdl);
			hdl = nullptr;
		}

		if (ctx)
		{
			SSL_CTX_free(ctx);
			ctx = nullptr;
		}
	}

	/// Clamps a buffer length to the non-negative int range taken by SSL_read/SSL_write.
	/// A plain narrowing cast would turn sizes above INT_MAX into negative lengths.
	int ClampToInt(const unsigned long long size)
	{
		constexpr unsigned long long maxLen = static_cast<unsigned long long>(std::numeric_limits<int>::max());

		return static_cast<int>(size > maxLen ? maxLen : size);
	}
}

namespace ehs
{
	/// Frees any OpenSSL state this object still owns.
	/// SSL_shutdown is only attempted while the socket is valid and connected.
	SSL::~SSL()
	{
		// Free unconditionally (not only when IsValid()), otherwise a context without a
		// handle, or state left on an already-invalid socket, would leak.
		FreeSslState(sslHdl, ctx, connection && TCP::IsValid());
	}

	SSL::SSL()
		: ctx(nullptr), sslHdl(nullptr)
	{
	}

	SSL::SSL(const AddrType type)
		: TCP(type), ctx(nullptr), sslHdl(nullptr)
	{
		SSL::Initialize();
	}

	SSL::SSL(TCP&& tcp) noexcept
		: TCP(std::move(tcp)), ctx(nullptr), sslHdl(nullptr)
	{
	}

	SSL::SSL(const TCP& tcp)
		: TCP(tcp), ctx(nullptr), sslHdl(nullptr)
	{
	}

	/// Copies only the TCP part; OpenSSL state is never shared between objects.
	SSL::SSL(const SSL& ssl)
		: TCP(ssl), ctx(nullptr), sslHdl(nullptr)
	{
	}

	/// Copies only the TCP part. This object's own OpenSSL state is released first,
	/// so that overwriting the pointers with null does not leak it.
	SSL& SSL::operator=(const SSL& ssl)
	{
		if (this == &ssl)
			return *this;

		// Tear down TLS while our socket is still the one it was set up on.
		FreeSslState(sslHdl, ctx, connection && TCP::IsValid());

		TCP::operator=(ssl);

		return *this;
	}

	/// Initializes the underlying TCP socket, then the OpenSSL library when no SSL handle exists yet.
	void SSL::Initialize()
	{
		TCP::Initialize();

		if (IsValid())
			return;

		SSL_library_init();
	}

	/// Releases OpenSSL state and then the underlying TCP socket.
	void SSL::Release()
	{
		// TLS must be torn down before the socket. SSL_shutdown needs the open descriptor,
		// and TCP::Release() presumably invalidates the socket, which previously made the
		// IsValid() check skip freeing the OpenSSL objects entirely.
		FreeSslState(sslHdl, ctx, connection && TCP::IsValid());

		TCP::Release();
	}

	/// Creates a server-side SSL context, then binds the TCP socket.
	/// @param address Local address to bind.
	/// @param port Local port to bind.
	/// Does nothing if already bound. Logs and returns without binding if OpenSSL setup fails.
	void SSL::Bind(const Str_8& address, unsigned short port)
	{
		if (bound)
			return;

		OpenSSL_add_all_algorithms();
		SSL_load_error_strings();

		ctx = SSL_CTX_new(SSLv23_server_method());
		if (!ctx)
		{
			EHS_LOG_INT(LogType::ERR, 2, "Failed to create SSL context.");
			return;
		}

		// NOTE(review): the listening socket gets its own handle; this presumably exists so
		// IsValid() holds for the server object. Accepted clients get their own handles.
		sslHdl = SSL_new(ctx);
		if (!sslHdl)
		{
			EHS_LOG_INT(LogType::ERR, 3, "Failed to create SSL handle.");
			return;
		}

		SSL_set_fd(sslHdl, hdl);

		TCP::Bind(address, port);
	}

	/// Accepts a pending TCP connection and performs the server-side TLS handshake.
	/// @returns A heap-allocated client owned by the caller. Returns nullptr if not bound,
	///          no connection was accepted, or TLS setup/handshake failed.
	///          With a non-blocking socket the client may be returned mid-handshake
	///          (SSL_ERROR_WANT_READ/WRITE); later reads/writes continue the handshake.
	SSL* SSL::Accept()
	{
		if (!bound || !ctx)
			return nullptr;

		TCP* tcp = TCP::Accept();
		if (!tcp)
			return nullptr;

		SSL* client = new SSL(std::move(*tcp));
		delete tcp;

		// Destroys a client whose TLS setup failed. The handle is freed directly instead of via
		// SSL_shutdown, which is not permitted after a fatal error / incomplete handshake.
		const auto discard = [](SSL* const failed) -> SSL*
		{
			SSL_free(failed->sslHdl);
			failed->sslHdl = nullptr;
			delete failed;
			return nullptr;
		};

		client->sslHdl = SSL_new(ctx);
		if (!client->sslHdl || SSL_set_fd(client->sslHdl, client->hdl) != 1)
		{
			EHS_LOG_INT(LogType::ERR, 1, "Failed to create SSL handle for client.");
			return discard(client);
		}

		const int result = SSL_accept(client->sslHdl);
		if (result <= 0)
		{
			const int code = SSL_get_error(client->sslHdl, result);

			// Non-blocking socket: the handshake is still in progress, not failed.
			if (code == SSL_ERROR_WANT_READ || code == SSL_ERROR_WANT_WRITE)
				return client;

			EHS_LOG_INT(LogType::ERR, 0, "Failed SSL handshake with error #" + Str_8::FromNum(code) + ".");

			return discard(client);
		}

		return client;
	}

	/// Connects the TCP socket, then performs the client-side TLS handshake.
	/// @param address Remote address.
	/// @param port Remote port.
	/// Does nothing if this socket is bound. Failures are logged.
	void SSL::Connect(const Str_8& address, const UInt_16 port)
	{
		if (bound)
			return;

		TCP::Connect(address, port);

		OpenSSL_add_all_algorithms();
		SSL_load_error_strings();

		ctx = SSL_CTX_new(SSLv23_client_method());
		if (!ctx)
		{
			EHS_LOG_INT(LogType::ERR, 2, "Failed to create SSL context.");
			return;
		}

		sslHdl = SSL_new(ctx);
		if (!sslHdl)
		{
			EHS_LOG_INT(LogType::ERR, 3, "Failed to create SSL handle.");
			return;
		}

		SSL_set_fd(sslHdl, hdl);

		const int result = SSL_connect(sslHdl);
		if (result <= 0)
		{
			const int code = SSL_get_error(sslHdl, result);

			// WANT_READ/WANT_WRITE only means a non-blocking handshake has not finished yet.
			if (code != SSL_ERROR_WANT_READ && code != SSL_ERROR_WANT_WRITE)
				EHS_LOG_INT(LogType::ERR, 0, "Failed SSL handshake with error #" + Str_8::FromNum(code) + ".");
		}
	}

	/// Writes up to `size` bytes from `buffer` over the TLS connection.
	/// Sizes above INT_MAX are clamped, so callers must check the returned count.
	/// @returns Bytes written, or 0 on error (the error is logged).
	UInt_64 SSL::Send(const Byte* const buffer, const UInt_32 size)
	{
		if (!sslHdl)
		{
			EHS_LOG_INT(LogType::ERR, 1, "Failed to send data, no SSL connection.");
			return 0;
		}

		const int written = SSL_write(sslHdl, buffer, ClampToInt(size));
		if (written <= 0)
		{
			const int code = SSL_get_error(sslHdl, written);
			ERR_print_errors_fp(stderr);
			EHS_LOG_INT(LogType::ERR, 0, "Failed to send data with error #" + Str_8::FromNum(code) + ".");
			return 0;
		}

		return static_cast<UInt_64>(written);
	}

	/// Reads up to `size` bytes from the TLS connection into `buffer`.
	/// Sizes above INT_MAX are clamped.
	/// @returns Bytes read, or 0 on error/closure (logged as an error).
	UInt_64 SSL::Receive(Byte* const buffer, const UInt_32 size)
	{
		if (!sslHdl)
		{
			EHS_LOG_INT(LogType::ERR, 1, "Failed to receive data, no SSL connection.");
			return 0;
		}

		const int received = SSL_read(sslHdl, buffer, ClampToInt(size));
		if (received <= 0)
		{
			const int code = SSL_get_error(sslHdl, received);
			ERR_print_errors_fp(stderr);
			EHS_LOG_INT(LogType::ERR, 0, "Failed to receive data with error #" + Str_8::FromNum(code) + ".");
			return 0;
		}

		return static_cast<UInt_64>(received);
	}

	/// Loads a DER-encoded X.509 certificate into this object's SSL context.
	/// @param data DER bytes.
	/// @param size Length of `data` in bytes.
	void SSL::UseCertificate(const Byte* data, const UInt_64 size)
	{
		if (!ctx)
		{
			EHS_LOG_INT(LogType::ERR, 2, "Failed to use certificate, no SSL context.");
			return;
		}

		X509* cert = d2i_X509(nullptr, &data, static_cast<long>(size));
		if (!cert)
		{
			EHS_LOG_INT(LogType::ERR, 0, "Invalid certificate.");
			return;
		}

		if (SSL_CTX_use_certificate(ctx, cert) != 1)
			EHS_LOG_INT(LogType::ERR, 1, "Failed to use certificate.");

		// The context keeps its own reference on success, so ours is released on every path.
		X509_free(cert);
	}

	/// Loads a DER-encoded RSA private key into this object's SSL context.
	/// @param data DER bytes.
	/// @param size Length of `data` in bytes.
	void SSL::UsePrivateKey(const Byte* data, const UInt_64 size)
	{
		if (!ctx)
		{
			EHS_LOG_INT(LogType::ERR, 2, "Failed to use private key, no SSL context.");
			return;
		}

		EVP_PKEY* key = d2i_PrivateKey(EVP_PKEY_RSA, nullptr, &data, static_cast<long>(size));
		if (!key)
		{
			EHS_LOG_INT(LogType::ERR, 0, "Invalid private key.");
			return;
		}

		if (SSL_CTX_use_PrivateKey(ctx, key) != 1)
			EHS_LOG_INT(LogType::ERR, 1, "Failed to use private key.");

		// The context keeps its own reference on success, so ours is released on every path.
		EVP_PKEY_free(key);
	}

	/// @returns True when the TCP socket is valid and an OpenSSL handle exists.
	bool SSL::IsValid()
	{
		return TCP::IsValid() && sslHdl;
	}
}