Files
EHS/src/io/socket/SSL.cpp

300 lines
5.7 KiB
C++

#include "ehs/io/socket/SSL.h"

#include <climits>

#include <openssl/ssl.h>
#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/pem.h>
#include <openssl/opensslv.h>
#include <openssl/provider.h>
namespace ehs
{
/// Frees every OpenSSL object this socket owns.
/// Each resource is released based on its own null check rather than IsValid(), which also
/// requires a live TCP socket and an SSL handle; gating on it leaked a context (and any
/// certificate/key) that was created but never attached to a handle.
SSL::~SSL()
{
	EVP_PKEY_free(pkey);
	X509_free(cert);

	if (sslHdl)
	{
		// Send the TLS close_notify only when the TCP layer reports a connection.
		if (connection)
			SSL_shutdown(sslHdl);

		SSL_free(sslHdl);
	}

	if (ctx)
		SSL_CTX_free(ctx);
}
SSL::SSL()
: server(false), ctx(nullptr), sslHdl(nullptr), cert(nullptr), pkey(nullptr)
{
}
SSL::SSL(const IP &type, const bool &server)
: TCP(type), server(server), ctx(nullptr), sslHdl(nullptr), cert(nullptr), pkey(nullptr)
{
SSL::Initialize();
}
SSL::SSL(TCP&& tcp) noexcept
: TCP(std::move(tcp)), server(false), ctx(nullptr), sslHdl(nullptr), cert(nullptr), pkey(nullptr)
{
}
SSL::SSL(const TCP& tcp)
: TCP(tcp), server(false), ctx(nullptr), sslHdl(nullptr), cert(nullptr), pkey(nullptr)
{
}
SSL::SSL(const SSL& ssl)
: TCP(ssl), server(ssl.server), ctx(nullptr), sslHdl(nullptr), cert(nullptr), pkey(nullptr)
{
}
/// Copy-assigns the TCP state and server flag of another TLS socket.
/// OpenSSL objects are not shared between copies, so this object's own context, handle,
/// certificate and key are released first (previously they were just overwritten and leaked).
/// The assigned object ends up with no OpenSSL state, matching the copy constructor.
SSL& SSL::operator=(const SSL& ssl)
{
	if (this == &ssl)
		return *this;

	EVP_PKEY_free(pkey);
	pkey = nullptr;

	X509_free(cert);
	cert = nullptr;

	if (sslHdl)
	{
		// Shut down while our own TCP state is still in place, before it is overwritten below.
		if (connection)
			SSL_shutdown(sslHdl);

		SSL_free(sslHdl);
		sslHdl = nullptr;
	}

	if (ctx)
	{
		SSL_CTX_free(ctx);
		ctx = nullptr;
	}

	TCP::operator=(ssl);
	server = ssl.server;

	return *this;
}
/// Initialises the TCP base, the OpenSSL library and this socket's SSL_CTX.
/// Server mode gets restricted cipher lists; client mode loads the default trust store and
/// requires peer certificate verification. Both require TLS 1.2 or newer.
/// Does nothing if a handle is already attached or a context was already created.
void SSL::Initialize()
{
	TCP::Initialize();

	// Checking ctx as well prevents a second call from leaking the existing context.
	if (IsValid() || ctx)
		return;

	SSL_library_init();
	OpenSSL_add_ssl_algorithms();
	SSL_load_error_strings();
	OPENSSL_init_ssl(OPENSSL_INIT_LOAD_CONFIG, nullptr);
	OSSL_PROVIDER_load(nullptr, "default");

	ctx = SSL_CTX_new(server ? TLS_server_method() : TLS_client_method());
	if (!ctx)
	{
		UInt_32 code = ERR_get_error();
		EHS_LOG_INT(LogType::ERR, 0, ERR_error_string(code, nullptr));
		// Every call below dereferences the context; stop here.
		return;
	}

	// Must run after the context exists; on a null context OpenSSL ignores it and the
	// TLS 1.2 minimum was never applied.
	SSL_CTX_set_min_proto_version(ctx, TLS1_2_VERSION);

	if (server)
	{
		// TLS <= 1.2 cipher selection (this already admits ECDHE-RSA-AES128-GCM-SHA256).
		SSL_CTX_set_cipher_list(ctx, "HIGH:!aNULL:!MD5");

		// TLS 1.3 suites only. The TLS 1.2 name previously listed here is not a TLS 1.3
		// suite and was silently ignored by OpenSSL.
		SSL_CTX_set_ciphersuites(ctx,
			"TLS_AES_256_GCM_SHA384:TLS_AES_128_GCM_SHA256:"
			"TLS_CHACHA20_POLY1305_SHA256"
		);
	}
	else
	{
		SSL_CTX_set_default_verify_paths(ctx);
		SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, nullptr);
	}
}
/// Tears down the TLS layer and then releases the TCP socket.
/// The TLS objects are freed first, while the socket still exists, so SSL_shutdown() can still
/// reach the peer. The old order released TCP first and then returned early on !IsValid(),
/// which meant the OpenSSL objects were never freed.
void SSL::Release()
{
	EVP_PKEY_free(pkey);
	pkey = nullptr;

	X509_free(cert);
	cert = nullptr;

	if (sslHdl)
	{
		if (connection)
			SSL_shutdown(sslHdl);

		SSL_free(sslHdl);
		sslHdl = nullptr;
	}

	if (ctx)
	{
		SSL_CTX_free(ctx);
		ctx = nullptr;
	}

	TCP::Release();
}
void SSL::Listen()
{
sslHdl = SSL_new(ctx);
SSL_set_fd(sslHdl, hdl);
TCP::Listen();
EHS_LOG_SUCCESS();
}
/// Accepts a pending TCP connection and performs the server-side TLS handshake.
/// @returns A heap-allocated client socket owned by the caller, or nullptr if the socket is not
///          bound, no connection was accepted, or the handshake failed.
SSL* SSL::Accept()
{
	if (!bound)
		return nullptr;

	TCP* tcp = TCP::Accept();
	if (!tcp)
		return nullptr;

	// Move the accepted socket into a TLS wrapper; the TCP shell is no longer needed.
	SSL* client = new SSL(std::move(*tcp));
	delete tcp;

	// The client's handle is created from this listener's context.
	client->sslHdl = SSL_new(this->ctx);
	if (!client->sslHdl)
	{
		UInt_32 code = ERR_get_error();
		EHS_LOG_INT(LogType::ERR, 0, ERR_error_string(code, nullptr));
		delete client; // previously leaked
		return nullptr;
	}

	SSL_set_fd(client->sslHdl, client->hdl);

	// SSL_get_error() is only reliable when the error queue is empty before the call.
	ERR_clear_error();

	// SSL_accept() returns 1 on success; 0 and negative values are both failures.
	// The old `!err` check treated fatal (negative) results as success.
	int rc = SSL_accept(client->sslHdl);
	if (rc != 1)
	{
		EHS_LOG_INT(LogType::ERR, 0, "Failed SSL handshake with error #" + Str_8::FromNum(SSL_get_error(client->sslHdl, rc)) + ".");

		// Free the handle directly so destroying the client does not attempt a TLS shutdown
		// on a connection whose handshake failed.
		SSL_free(client->sslHdl);
		client->sslHdl = nullptr;
		delete client; // previously leaked
		return nullptr;
	}

	EHS_LOG_SUCCESS();

	return client;
}
/// Connects over TCP and performs the client-side TLS handshake.
/// The peer certificate is verified against @p address: as a DNS hostname, or as an IP address
/// if @p address parses as one. SSL_VERIFY_PEER alone only validates the chain, so without this
/// any certificate from a trusted CA would have been accepted.
/// @param address Host name or IP literal; also sent as SNI.
/// @param port    Remote port.
void SSL::Connect(Str_8 address, const UInt_16 &port)
{
	TCP::Connect(address, port);

	// Drop a handle left over from a previous connection instead of leaking it.
	if (sslHdl)
	{
		SSL_free(sslHdl);
		sslHdl = nullptr;
	}

	sslHdl = SSL_new(ctx);
	if (!sslHdl)
	{
		UInt_32 code = ERR_get_error();
		EHS_LOG_INT(LogType::ERR, 1, ERR_error_string(code, nullptr));
		return;
	}

	SSL_set_fd(sslHdl, hdl);

	const char* host = reinterpret_cast<const char*>(&address[0]);

	SSL_set_tlsext_host_name(sslHdl, host);

	// Bind the expected identity: an IP literal is matched against IP SANs, anything else
	// against DNS names. X509_VERIFY_PARAM_set1_ip_asc() returns 0 for non-IP input.
	if (!X509_VERIFY_PARAM_set1_ip_asc(SSL_get0_param(sslHdl), host))
		SSL_set1_host(sslHdl, host);

	// SSL_get_error() is only reliable when the error queue is empty before the call.
	ERR_clear_error();

	SInt_32 rc = SSL_connect(sslHdl);
	if (rc != 1)
	{
		EHS_LOG_INT(LogType::ERR, 1, "Failed to connect with error #" + Str_8::FromNum(SSL_get_error(sslHdl, rc)) + ".");
		return;
	}

	EHS_LOG_SUCCESS();
}
/// Writes up to @p size bytes from @p buffer over the TLS connection.
/// @returns Number of bytes written, or 0 if nothing was written (would-block, or an error,
///          which is logged).
UInt_64 SSL::Send(const Byte* const buffer, const UInt_32 size)
{
	if (!sslHdl)
	{
		EHS_LOG_INT(LogType::ERR, 0, "Cannot send data without an SSL handle.");
		return 0;
	}

	// SSL_write() takes an int length; clamp so sizes above INT_MAX do not become negative.
	const int len = size > static_cast<UInt_32>(INT_MAX) ? INT_MAX : static_cast<int>(size);

	// SSL_get_error() is only reliable when the error queue is empty before the I/O call.
	ERR_clear_error();

	const int written = SSL_write(sslHdl, buffer, len);
	if (written <= 0)
	{
		const int code = SSL_get_error(sslHdl, written);

		// WANT_READ/WANT_WRITE mean "retry later"; a write can need a read (e.g. renegotiation).
		if (code != SSL_ERROR_WANT_WRITE && code != SSL_ERROR_WANT_READ)
		{
			ERR_print_errors_fp(stderr);
			EHS_LOG_INT(LogType::ERR, 0, "Failed to send data with error #" + Str_8::FromNum(code) + ".");
		}

		return 0;
	}

	return static_cast<UInt_64>(written);
}
/// Reads up to @p size bytes from the TLS connection into @p buffer.
/// @returns Number of bytes read, or 0 if nothing was read (would-block, clean close by the
///          peer, or an error, which is logged).
UInt_64 SSL::Receive(Byte* const buffer, const UInt_32 size)
{
	if (!sslHdl)
	{
		EHS_LOG_INT(LogType::ERR, 0, "Cannot receive data without an SSL handle.");
		return 0;
	}

	// SSL_read() takes an int length; clamp so sizes above INT_MAX do not become negative.
	const int len = size > static_cast<UInt_32>(INT_MAX) ? INT_MAX : static_cast<int>(size);

	// SSL_get_error() is only reliable when the error queue is empty before the I/O call.
	ERR_clear_error();

	const int received = SSL_read(sslHdl, buffer, len);
	if (received <= 0)
	{
		const int code = SSL_get_error(sslHdl, received);

		// Not failures:
		//  - WANT_READ/WANT_WRITE: retry later (a read can need a write, e.g. renegotiation).
		//  - ZERO_RETURN: the peer closed the TLS connection with close_notify.
		if (code != SSL_ERROR_WANT_READ && code != SSL_ERROR_WANT_WRITE && code != SSL_ERROR_ZERO_RETURN)
		{
			ERR_print_errors_fp(stderr);
			EHS_LOG_INT(LogType::ERR, 0, "Failed to receive data with error #" + Str_8::FromNum(code) + ".");
		}

		return 0;
	}

	return static_cast<UInt_64>(received);
}
void SSL::UseCertificate(const Char_8* data, const UInt_32 &size)
{
BIO* certBio = BIO_new_mem_buf(data, (int)size);
X509 *cert = PEM_read_bio_X509(certBio, nullptr, nullptr, nullptr);
if (!SSL_CTX_use_certificate(ctx, cert))
{
UInt_32 code = ERR_get_error();
EHS_LOG_INT(LogType::ERR, 0, ERR_error_string(code, nullptr));
return;
}
BIO_free(certBio);
}
void SSL::UsePrivateKey(const Char_8* data, const UInt_32 &size)
{
BIO* keyBio = BIO_new_mem_buf(data, (int)size);
EVP_PKEY* pkey = PEM_read_bio_PrivateKey(keyBio, nullptr, nullptr, nullptr);
if (!SSL_CTX_use_PrivateKey(ctx, pkey))
{
UInt_32 code = ERR_get_error();
EHS_LOG_INT(LogType::ERR, 0, ERR_error_string(code, nullptr));
return;
}
BIO_free(keyBio);
}
/// Reports whether the TCP socket is valid and a TLS handle is attached to it.
bool SSL::IsValid()
{
	if (!TCP::IsValid())
		return false;

	return sslHdl != nullptr;
}
}