#include "ehs/Task.h"

namespace ehs
{
	/// Worker-thread entry point used by Task.
	/// @param args Serializer packed by Task::Initialize containing, in this order: the "available" semaphore,
	///             the "done" semaphore, a pointer to the job-argument slot and a pointer to the callback slot.
	/// @returns 1 if args is null or waiting on "available" fails, 0 when woken up with a null callback.
	/// Each pass through the loop ends with exactly one Signal(1) on the "done" semaphore.
	UInt_32 TaskThread(void* args)
	{
		if (!args)
			return 1;

		auto* const packed = static_cast<Serializer<UInt_64>*>(args);

		// Read order must mirror the Write order in Task::Initialize.
		Semaphore* const workReady = packed->Read<Semaphore*>();
		Semaphore* const workDone = packed->Read<Semaphore*>();
		Serializer<UInt_64>** const argSlot = packed->Read<Serializer<UInt_64>**>();
		TaskCb* const cbSlot = packed->Read<TaskCb*>();

		UInt_32 exitCode = 0;

		for (;;)
		{
			if (!workReady->Wait(EHS_INFINITE))
			{
				exitCode = 1;
				break;
			}

			// A null callback is the owner's request to shut the worker down.
			if (!*cbSlot)
				break;

			(*cbSlot)(*argSlot);
			workDone->Signal(1);
		}

		// Both exit paths acknowledge the final wake-up on "done".
		workDone->Signal(1);
		return exitCode;
	}

	/// Shuts down the worker thread and frees every heap object shared with it, including the startup
	/// argument serializer (threadArgs). Does nothing for a task that is not valid (e.g. a moved-from task,
	/// whose pointers were nulled).
	Task::~Task()
	{
		if (!IsValid())
			return;

		// If a job was handed out and not collected with WaitUntilDone(), give it a bounded wait
		// (1000, presumably milliseconds) so its completion signal is consumed before shutdown.
		// NOTE(review): if the job takes longer, the callback slot below is written while the worker may not
		// yet have read it; thread.Join() presumably still blocks until the worker exits.
		if (working)
			done->Wait(1000);

		// A null callback makes TaskThread signal "done" and return instead of running a job.
		*callback = nullptr;
		available->Signal(1);

		thread.Join();

		delete available;
		delete done;
		delete *cbArgs;
		delete cbArgs;
		delete callback;

		// threadArgs is allocated by Initialize() and only read by the worker when it starts; the worker has
		// been joined, so it is safe (and necessary, to avoid a leak) to free it here.
		delete threadArgs;
	}

	Task::Task()
		: working(false), available(nullptr), done(nullptr), cbArgs(nullptr), callback(nullptr), threadArgs(nullptr)
	{
		Initialize();
	}

	/// Takes over the running worker and all shared state of another task, leaving that task invalid.
	Task::Task(Task&& task) noexcept
		: working(task.working),
		  available(task.available),
		  done(task.done),
		  cbArgs(task.cbArgs),
		  callback(task.callback),
		  threadArgs(task.threadArgs),
		  thread(std::move(task.thread))
	{
		// Null out the source so its destructor/Release() treat it as invalid and leave the worker alone.
		task.threadArgs = nullptr;
		task.callback = nullptr;
		task.cbArgs = nullptr;
		task.done = nullptr;
		task.available = nullptr;
		task.working = false;
	}

	/// "Copies" a task by starting a brand-new idle worker; nothing from the source task is shared or copied.
	Task::Task(const Task& /* task */)
		: Task()
	{
	}

	/// Shuts down this task's worker, then takes over the worker and shared state of another task,
	/// leaving that task invalid. Self-assignment is a no-op.
	Task& Task::operator=(Task&& task) noexcept
	{
		if (this != &task)
		{
			Release();

			// Adopt the source's worker and every heap object it references.
			working = task.working;
			available = task.available;
			done = task.done;
			cbArgs = task.cbArgs;
			callback = task.callback;
			threadArgs = task.threadArgs;
			thread = std::move(task.thread);

			// Leave the source empty so it never touches the adopted worker again.
			task.working = false;
			task.available = nullptr;
			task.done = nullptr;
			task.cbArgs = nullptr;
			task.callback = nullptr;
			task.threadArgs = nullptr;
			task.thread = {};
		}

		return *this;
	}

	/// "Copy" assignment: discards this task's worker and starts a fresh idle one. Nothing from the source task
	/// is shared or copied. Self-assignment is a no-op.
	Task& Task::operator=(const Task& task)
	{
		if (this != &task)
		{
			Release();
			Initialize();
		}

		return *this;
	}

	/// Restarts a valid task: shuts down its worker and builds a new one.
	/// NOTE(review): an invalid task (e.g. moved-from) is left untouched rather than re-initialized;
	/// confirm this is intended given the method's name.
	void Task::Revalidate()
	{
		if (IsValid())
		{
			Release();
			Initialize();
		}
	}

	void Task::Initialize()
	{
		if (IsValid())
			return;

		working = false;
		available = new Semaphore(0);
		done = new Semaphore(0);
		cbArgs = new Serializer<ehs::UInt_64>*(new Serializer<UInt_64>());
		callback = new TaskCb(nullptr);

		threadArgs = new Serializer<UInt_64>(Endianness::LE);
		threadArgs->Write(available);
		threadArgs->Write(done);
		threadArgs->Write(cbArgs);
		threadArgs->Write(callback);
		threadArgs->SetOffset(0);

		thread.Start(TaskThread, threadArgs);
	}

	void Task::Release()
	{
		if (!IsValid())
			return;

		if (!working)
		{
			*callback = nullptr;
			available->Signal(1);
		}
		else
		{
			done->Wait(1000);
			*callback = nullptr;
			available->Signal(1);
		}

		working = false;
		thread.Join();
		delete available;
		available = nullptr;
		delete done;
		done = nullptr;
		delete *cbArgs;
		delete cbArgs;
		cbArgs = nullptr;
		delete callback;
		callback = nullptr;
	}

	bool Task::IsWorking() const
	{
		return working;
	}

	/// Hands a job to the worker thread and wakes it up.
	/// @param args Arguments moved into the slot the worker passes to @p cb.
	/// @param cb   Callback the worker invokes with the arguments; must not be null.
	/// The request is rejected with a log message if the task is invalid, is still working on a previous job,
	/// or @p cb is null.
	void Task::GiveWork(Serializer<UInt_64> args, TaskCb cb)
	{
		// A moved-from (or never-initialized) task has null slots; writing through cbArgs would crash.
		if (!IsValid())
		{
			EHS_LOG_INT("Error", 1, "Attempted to give work to an invalid task.");
			return;
		}

		if (working)
		{
			EHS_LOG_INT("Warning", 0, "Attempted to give work while task is still working.");
			return;
		}

		// TaskThread treats a null callback as its shutdown request and exits. Accepting one here would leave the
		// task without a worker: later jobs would never run and WaitUntilDone() would block forever.
		if (!cb)
		{
			EHS_LOG_INT("Error", 2, "Attempted to give work with a null callback.");
			return;
		}

		**cbArgs = std::move(args);
		*callback = cb;
		working = true;

		available->Signal(1);
	}

	/// Blocks until the worker signals completion of the current job, then marks the task idle.
	/// Returns immediately if no job is outstanding.
	void Task::WaitUntilDone()
	{
		if (working)
		{
			done->Wait(EHS_INFINITE);
			working = false;
		}
	}

	/// @returns true when the worker thread handle is valid, both semaphores exist and are valid,
	///          and the argument and callback slots are allocated.
	bool Task::IsValid() const
	{
		if (!thread.IsValid())
			return false;

		// Null checks come first so the IsValid() calls never dereference a null pointer.
		if (!available || !available->IsValid())
			return false;

		if (!done || !done->IsValid())
			return false;

		return cbArgs && callback;
	}
}