#include <ehs/io/Console.h>
#include <ehs/io/socket/ICMP.h>
#include <ehs/io/socket/DNS.h>
#include <ehs/Serializer.h>
#include <ehs/Math.h>
#include <ehs/HRNG.h>
#include <ehs/system/Thread.h>
#include <ehs/system/CPU.h>
#include <ehs/Str.h>

using namespace ehs;

/// An IPv4 address viewed either as one 32-bit word or as its four bytes.
/// octets[0] is the leftmost dotted-decimal component: StrToAddress fills it from
/// the first "."-separated field and AddressToStr prints it first.
/// Because octets[0] aliases the lowest-addressed byte of `address`, the numeric
/// value of `address` depends on host byte order. Only byte-wise bitwise
/// operations (&, |, ~) on it give the same octets on every host.
/// NOTE(review): reading the member that was not last written is type punning,
/// which standard C++ leaves undefined for unions (commonly supported as a
/// compiler extension).
union IPv4
{
	UInt_32 address; // all four octets as one word, used for masking
	UInt_8 octets[4]; // dotted-decimal components, leftmost first
};

/// Scan outcome for a single host address.
struct Result
{
	IPv4 address; // host that gets pinged
	bool response; // set to true by Ping when a reply carrying the request's id arrives
	UInt_8 padding[3]; // explicit padding; presumably so sizeof(Result) is 8 and a whole
	                   // number of entries fits in a cache line (see resultsPerThread) — confirm
};

/// Work item handed to one Ping thread: a contiguous slice of main's results array.
struct PingData
{
	Result* results; // first entry of this thread's slice (points into the shared results array)
	UInt_64 count; // number of entries in the slice
};

// Number of Result entries each ping thread handles: as many as fit in one CPU cache
// line, presumably so different threads write to different cache lines. Nothing in this
// file aligns the results array, so the separation is best effort.
// NOTE(review): this is 0 if the reported cache-line size is smaller than sizeof(Result).
// It is also evaluated during static initialization, i.e. before main() calls Initialize().
UInt_64 resultsPerThread = CPU::GetCacheLineSize() / sizeof(Result);

/// Joins the four octets of an address with ".", octets[0] first.
/// NOTE(review): relies on Str_8's operator+= overload for UInt_8 appending the
/// octet's decimal text; Str_8::FromNum would make that explicit.
/// @param address Address to format.
/// @return Dotted text form of the address.
Str_8 AddressToStr(const IPv4 &address)
{
	Str_8 text;

	for (UInt_8 idx = 0; idx < 4; ++idx)
	{
		if (idx)
			text += ".";

		text += address.octets[idx];
	}

	return text;
}

/// Thread entry point. Sends one ICMP echo request to each address in the given
/// slice and marks the entry's `response` when a reply with the matching id arrives.
/// @param data Pointer to the PingData describing this thread's slice of results.
/// @return Always 0.
UInt_32 Ping(void *data)
{
	PingData* pingData = (PingData*)data;

	// Upper bound on packets read per address while looking for the matching reply,
	// so a steady stream of unrelated ICMP traffic cannot stall the thread forever.
	const UInt_32 maxReceiveAttempts = 32;

	Serializer<UInt_64> payload(Endianness::LE);
	payload.WriteStr("Hello world!");

	ICMP icmp(IP::V4);
	icmp.SetReceiveTimeout(1);

	for (UInt_64 i = 0; i < pingData->count; ++i)
	{
		const UInt_16 id = HRNG::Generate_u16();

		icmp.SendEchoRequest(AddressToStr(pingData->results[i].address), {id, (UInt_16)(i + 1)}, payload, payload.Size());

		// The socket may deliver packets that are not the reply to this request. For
		// example, a raw ICMP socket can see replies addressed to other threads'
		// sockets, or a late reply to an earlier request may arrive. Skip those and
		// keep waiting. Stopping after the first packet would drop our reply and leave
		// it queued to be mismatched against the next request.
		for (UInt_32 attempt = 0; attempt < maxReceiveAttempts; ++attempt)
		{
			Str_8 inAddr;
			ICMP_Header header = {};
			Serializer<UInt_64> inPayload(Endianness::LE);

			// Zero bytes means nothing arrived before the receive timeout: give up on this host.
			if (!icmp.Receive(inAddr, header, inPayload))
				break;

			const ICMP_EchoRequest er = inPayload.Read<ICMP_EchoRequest>();
			if (er.id != id)
				continue;

			pingData->results[i].response = true;
			break;
		}
	}

	return 0;
}

/// Parses a dotted-decimal IPv4 string such as "192.168.1.1".
/// @param address Text of the form "a.b.c.d".
/// @return The parsed address, with the first field in octets[0]. Returns an all-zero
///         address (0.0.0.0) when the text does not split into exactly four
///         "."-separated fields or when a field's value exceeds 255.
/// NOTE(review): extremely long digit strings may still wrap inside ToDecimal<UInt_32>.
IPv4 StrToAddress(const Str_8 &address)
{
	const Vector<Str_8> octets = address.Split(".");
	if (octets.Size() != 4)
		return {};

	IPv4 result = {};
	for (UInt_8 i = 0; i < 4; ++i)
	{
		// Parse wider than a byte so out-of-range fields (e.g. "300") are rejected
		// instead of silently wrapping into a different, valid-looking address.
		const UInt_32 value = octets[i].ToDecimal<UInt_32>();
		if (value > EHS_UINT_8_MAX)
			return {};

		result.octets[i] = (UInt_8)value;
	}

	return result;
}

/// Classifies an address by its first octet using classful IPv4 ranges.
/// @param address Address to classify.
/// @return 0 for class A (1-126), 1 for class B (128-191), 2 for class C (192-223),
///         or EHS_UINT_8_MAX for everything else (0, 127, and 224-255).
UInt_8 GetClass(const IPv4 &address)
{
	const UInt_8 first = address.octets[0];

	if (first == 0 || first == 127 || first >= 224)
		return EHS_UINT_8_MAX;

	if (first < 128)
		return 0;

	if (first < 192)
		return 1;

	return 2;
}

/// Returns the classful default subnet mask for a class index from GetClass.
/// The mask is built octet by octet, so the same bytes result on little- and
/// big-endian hosts. The previous host-order UInt_32 constants were only correct
/// on little-endian machines.
/// @param ipClass 0 (A), 1 (B) or 2 (C); any other value yields 0.0.0.0.
/// @return 255.0.0.0, 255.255.0.0, 255.255.255.0, or 0.0.0.0.
IPv4 GetSubnetMask(const UInt_8 &ipClass)
{
	// Number of leading octets set to 255 in the mask.
	UInt_8 maskedOctets;
	switch (ipClass)
	{
		case 0:
			maskedOctets = 1; // class A: 255.0.0.0
			break;
		case 1:
			maskedOctets = 2; // class B: 255.255.0.0
			break;
		case 2:
			maskedOctets = 3; // class C: 255.255.255.0
			break;
		default:
			maskedOctets = 0;
			break;
	}

	IPv4 mask = {};
	for (UInt_8 i = 0; i < maskedOctets; ++i)
		mask.octets[i] = 0xFF;

	return mask;
}

int main()
{
	Initialize("Device Scanner", "Release", {1, 0, 0});
	Log::EnableImmediateMode(true);

	Console::Write_8("Default Gateway: ", false);
	Str_8 gateway = Console::Read_8();
	IPv4 octets = StrToAddress(gateway);
	UInt_8 ipClass = GetClass(octets);
	IPv4 subnetMask = GetSubnetMask(ipClass);

	IPv4 networkAddress = {octets.address & subnetMask.address};
	IPv4 invertedMask = {~subnetMask.address};
	IPv4 broadcastAddress = {octets.address | invertedMask.address};

	IPv4 firstAddress = {networkAddress.address};
	firstAddress.octets[3] += 1;

	IPv4 lastAddress = {broadcastAddress.address};
	lastAddress.octets[3] -= 1;

	Console::Write_8("\nIP Class: " + Str_8::FromNum(ipClass));
	Console::Write_8("Subnet Mask: " + AddressToStr(subnetMask));
	Console::Write_8("Network Address: " + AddressToStr(networkAddress));
	Console::Write_8("Inverted Mask: " + AddressToStr(invertedMask));
	Console::Write_8("Broadcast Address: " + AddressToStr(broadcastAddress));
	Console::Write_8("First Address: " + AddressToStr(firstAddress));
	Console::Write_8("Last Address: " + AddressToStr(lastAddress) + "\n");

	Array<Result> results(lastAddress.octets[3] - firstAddress.octets[3]);
	for (UInt_64 i = 0; i < results.Size(); ++i)
	{
		results[i] = {};
		results[i].address = networkAddress;
		results[i].address.octets[3] = firstAddress.octets[3] + i;
		results[i].response = false;
	}

	Array<Thread> threads((UInt_64)Math::Ceil((double)results.Size() / (double)resultsPerThread));
	Array<PingData> pings(threads.Size());

	for (UInt_64 i = 0; i < threads.Size(); ++i)
	{
		UInt_64 offset = i * resultsPerThread;
		if (offset + resultsPerThread > results.Size())
		{
			UInt_64 delta = results.Size() - offset;

			pings[i].results = &results[offset];
			pings[i].count = delta;
		}
		else
		{
			pings[i].results = &results[offset];
			pings[i].count = resultsPerThread;
		}
	}

	for (UInt_64 i = 0; i < threads.Size(); ++i)
		threads[i].Start(Ping, &pings[i]);

	for (UInt_64 i = 0; i < threads.Size(); ++i)
		threads[i].Join();

	Console::Write_8("Active Devices:");

	for (UInt_64 i = 0; i < results.Size(); ++i)
	{
		if (!results[i].response)
			continue;

		Console::Write_8(AddressToStr(results[i].address));
	}

	Uninitialize();

	return 0;
}