EHS/src/io/socket/SSL.cpp

237 lines
3.9 KiB
C++

#include "ehs/io/socket/SSL.h"

#include <climits>
#include <cstdio>

#include <openssl/ssl.h>
#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/pem.h>
namespace ehs
{
/// @brief Frees the OpenSSL handle and context owned by this socket.
///
/// Cleanup is performed whenever the handles exist, independent of IsValid(), so an object whose TCP
/// socket is already invalid, or that owns a context but no handle, does not leak its OpenSSL state.
SSL::~SSL()
{
    if (sslHdl)
    {
        // Only attempt the close_notify exchange while the underlying socket is still usable.
        if (connection && TCP::IsValid())
            SSL_shutdown(sslHdl);

        SSL_free(sslHdl);
        sslHdl = nullptr;
    }

    if (ctx)
    {
        SSL_CTX_free(ctx);
        ctx = nullptr;
    }
}
SSL::SSL()
: ctx(nullptr), sslHdl(nullptr)
{
}
SSL::SSL(const AddrType type)
: TCP(type), ctx(nullptr), sslHdl(nullptr)
{
SSL::Initialize();
}
SSL::SSL(TCP&& tcp) noexcept
: TCP(std::move(tcp)), ctx(nullptr), sslHdl(nullptr)
{
}
SSL::SSL(const TCP& tcp)
: TCP(tcp), ctx(nullptr), sslHdl(nullptr)
{
}
SSL::SSL(const SSL& ssl)
: TCP(ssl), ctx(nullptr), sslHdl(nullptr)
{
}
/// @brief Copy-assigns the TCP part of @p ssl. This object's OpenSSL state is released rather than shared.
///
/// Any handle/context currently owned by *this is freed first; overwriting them with nullptr without
/// freeing would leak them.
/// @param ssl The SSL socket to copy from.
/// @return *this
SSL& SSL::operator=(const SSL& ssl)
{
    if (this == &ssl)
        return *this;

    if (sslHdl)
    {
        if (connection)
            SSL_shutdown(sslHdl);

        SSL_free(sslHdl);
        sslHdl = nullptr;
    }

    if (ctx)
    {
        SSL_CTX_free(ctx);
        ctx = nullptr;
    }

    TCP::operator=(ssl);

    return *this;
}
/// @brief Initializes the TCP layer, then the OpenSSL library.
///
/// OpenSSL setup is skipped when IsValid() already reports a usable socket together with an SSL handle.
void SSL::Initialize()
{
    TCP::Initialize();

    if (!IsValid())
        SSL_library_init();
}
/// @brief Shuts down and frees the OpenSSL state, then releases the underlying TCP socket.
///
/// The SSL teardown runs BEFORE TCP::Release(). Otherwise IsValid() is already false by the time the
/// handles are checked, and they would never be freed. Running it first also means SSL_shutdown still
/// has an open socket to send close_notify on.
void SSL::Release()
{
    if (sslHdl)
    {
        if (connection && TCP::IsValid())
            SSL_shutdown(sslHdl);

        SSL_free(sslHdl);
        sslHdl = nullptr;
    }

    if (ctx)
    {
        SSL_CTX_free(ctx);
        ctx = nullptr;
    }

    TCP::Release();
}
void SSL::Bind(const Str_8& address, unsigned short port)
{
if (bound)
return;
OpenSSL_add_all_algorithms();
SSL_load_error_strings();
ctx = SSL_CTX_new(SSLv23_server_method());
sslHdl = SSL_new(ctx);
SSL_set_fd(sslHdl, hdl);
TCP::Bind(address, port);
}
/// @brief Accepts a pending TCP connection and performs the server-side TLS handshake on it.
///
/// The returned client gets its own handle created from this socket's context, and does not own a
/// context itself (client->ctx stays nullptr).
///
/// For non-blocking sockets, a handshake that reports SSL_ERROR_WANT_READ/WANT_WRITE is not treated as a
/// failure. The client is returned and the handshake completes during later SSL I/O.
/// @return A heap-allocated client owned by the caller, or nullptr if there is no connection, this
///         socket is not bound, or the handshake failed.
SSL* SSL::Accept()
{
    if (!bound)
        return nullptr;

    TCP* tcp = TCP::Accept();
    if (!tcp)
        return nullptr;

    SSL* client = new SSL(std::move(*tcp));
    delete tcp;

    client->ctx = nullptr;
    client->sslHdl = SSL_new(ctx);
    if (!client->sslHdl)
    {
        EHS_LOG_INT(LogType::ERR, 1, "Failed to create SSL handle for accepted client.");
        delete client;
        return nullptr;
    }

    SSL_set_fd(client->sslHdl, client->hdl);

    // SSL_get_error is only reliable when the thread's error queue is empty before the I/O call.
    ERR_clear_error();

    // SSL_accept returns 1 on success, 0 on a controlled failure and < 0 on a fatal error.
    const int result = SSL_accept(client->sslHdl);
    if (result != 1)
    {
        const int code = SSL_get_error(client->sslHdl, result);
        if (code == SSL_ERROR_WANT_READ || code == SSL_ERROR_WANT_WRITE)
            return client;

        EHS_LOG_INT(LogType::ERR, 0, "Failed SSL handshake with error #" + Str_8::FromNum(code) + ".");

        // Free the handle directly so the client's destructor does not call SSL_shutdown after a failed
        // handshake, then release the client itself instead of leaking it.
        SSL_free(client->sslHdl);
        client->sslHdl = nullptr;
        delete client;
        return nullptr;
    }

    return client;
}
void SSL::Connect(const Str_8& address, const UInt_16 port)
{
if (bound)
return;
TCP::Connect(address, port);
OpenSSL_add_all_algorithms();
SSL_load_error_strings();
ctx = SSL_CTX_new(SSLv23_client_method());
sslHdl = SSL_new(ctx);
SSL_set_fd(sslHdl, hdl);
SSL_connect(sslHdl);
}
/// @brief Writes up to @p size bytes from @p buffer over the TLS connection.
///
/// Sizes above INT_MAX are clamped (SSL_write takes an int length), so the call may write fewer bytes
/// than requested. The caller should use the returned count.
/// @param buffer Data to send.
/// @param size Number of bytes to send.
/// @return Number of bytes written. Returns 0 on error, when there is no SSL handle, when @p size is 0,
///         or when a non-blocking socket reports SSL_ERROR_WANT_READ/WANT_WRITE (retry later).
UInt_64 SSL::Send(const Byte* const buffer, const UInt_32 size)
{
    if (!sslHdl)
    {
        EHS_LOG_INT(LogType::ERR, 1, "Cannot send data without an SSL handle.");
        return 0;
    }

    // Nothing to send; the OpenSSL docs advise against calling SSL_write with a length of zero.
    if (!size)
        return 0;

    // Without the clamp, (int)size would turn lengths above INT_MAX into a negative length.
    const int request = size > static_cast<UInt_32>(INT_MAX) ? INT_MAX : static_cast<int>(size);

    // SSL_get_error is only reliable when the thread's error queue is empty before the I/O call.
    ERR_clear_error();

    const int written = SSL_write(sslHdl, buffer, request);
    if (written <= 0)
    {
        const int code = SSL_get_error(sslHdl, written);

        // Non-blocking socket: the write must be retried later; this is not an error.
        if (code == SSL_ERROR_WANT_READ || code == SSL_ERROR_WANT_WRITE)
            return 0;

        ERR_print_errors_fp(stderr);
        EHS_LOG_INT(LogType::ERR, 0, "Failed to send data with error #" + Str_8::FromNum(code) + ".");
        return 0;
    }

    return static_cast<UInt_64>(written);
}
/// @brief Reads up to @p size bytes from the TLS connection into @p buffer.
///
/// Sizes above INT_MAX are clamped (SSL_read takes an int length).
/// @param buffer Destination buffer; must hold at least @p size bytes.
/// @param size Capacity of @p buffer in bytes.
/// @return Number of bytes read. Returns 0 in these cases:
///         - on error, when there is no SSL handle, or when @p size is 0;
///         - when the peer closed the TLS session cleanly (SSL_ERROR_ZERO_RETURN);
///         - when a non-blocking socket reports SSL_ERROR_WANT_READ/WANT_WRITE (retry later).
UInt_64 SSL::Receive(Byte* const buffer, const UInt_32 size)
{
    if (!sslHdl)
    {
        EHS_LOG_INT(LogType::ERR, 1, "Cannot receive data without an SSL handle.");
        return 0;
    }

    if (!size)
        return 0;

    // Without the clamp, (int)size would turn lengths above INT_MAX into a negative length.
    const int request = size > static_cast<UInt_32>(INT_MAX) ? INT_MAX : static_cast<int>(size);

    // SSL_get_error is only reliable when the thread's error queue is empty before the I/O call.
    ERR_clear_error();

    const int received = SSL_read(sslHdl, buffer, request);
    if (received <= 0)
    {
        const int code = SSL_get_error(sslHdl, received);

        // WANT_*: non-blocking socket, retry later. ZERO_RETURN: peer sent close_notify (orderly EOF).
        // None of these are failures.
        if (code == SSL_ERROR_WANT_READ || code == SSL_ERROR_WANT_WRITE || code == SSL_ERROR_ZERO_RETURN)
            return 0;

        ERR_print_errors_fp(stderr);
        EHS_LOG_INT(LogType::ERR, 0, "Failed to receive data with error #" + Str_8::FromNum(code) + ".");
        return 0;
    }

    return static_cast<UInt_64>(received);
}
/// @brief Loads a DER-encoded X.509 certificate into this socket's OpenSSL context.
///
/// Requires a context (created by Bind() or Connect()). Handles created from the context afterwards,
/// such as clients returned by Accept(), pick the certificate up.
/// @param data Pointer to DER-encoded certificate bytes.
/// @param size Number of bytes at @p data.
void SSL::UseCertificate(const Byte* data, const UInt_64 size)
{
    if (!ctx)
    {
        EHS_LOG_INT(LogType::ERR, 2, "Cannot use a certificate without an SSL context.");
        return;
    }

    X509* cert = d2i_X509(nullptr, &data, static_cast<long>(size));
    if (!cert)
    {
        EHS_LOG_INT(LogType::ERR, 0, "Invalid certificate.");
        return;
    }

    const bool used = SSL_CTX_use_certificate(ctx, cert) == 1;

    // Free our reference on both paths: on success the context keeps its own reference, and on failure
    // returning early without freeing would leak the parsed certificate.
    X509_free(cert);

    if (!used)
        EHS_LOG_INT(LogType::ERR, 1, "Failed to use certificate.");
}
void SSL::UsePrivateKey(const Byte* data, const UInt_64 size)
{
EVP_PKEY *key = d2i_PrivateKey(EVP_PKEY_RSA, nullptr, &data, (long)size);
if (!key)
{
EHS_LOG_INT(LogType::ERR, 0, "Invalid private key.");
return;
}
if (SSL_CTX_use_PrivateKey(ctx, key) != 1)
{
EHS_LOG_INT(LogType::ERR, 1, "Failed to use private key.");
return;
}
EVP_PKEY_free(key);
}
/// @brief Reports whether the TCP socket is valid AND an OpenSSL handle has been created.
/// @return true only when both conditions hold.
bool SSL::IsValid()
{
    if (!TCP::IsValid())
        return false;

    return sslHdl != nullptr;
}
}