/*
   This file is part of Fennix Kernel.

   Fennix Kernel is free software: you can redistribute it and/or
   modify it under the terms of the GNU General Public License as
   published by the Free Software Foundation, either version 3 of
   the License, or (at your option) any later version.

   Fennix Kernel is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with Fennix Kernel. If not, see <https://www.gnu.org/licenses/>.
*/

#include <net/arp.hpp>
#include <debug.h>
#include <smart_ptr.hpp>

#include "../kernel.h"

/* conversion from ‘uint48_t’ {aka ‘long unsigned int’} to ‘long unsigned int:48’ may change value */
#pragma GCC diagnostic ignored "-Wconversion"

namespace NetworkARP
{
	DiscoveredAddress *ARP::ManageDiscoveredAddresses(DAType Type, InternetProtocol IP, MediaAccessControl MAC)
	{
		switch (Type)
		{
		case DA_ADD:
		{
			DiscoveredAddress *tmp = new DiscoveredAddress;
			tmp->IP = IP;
			tmp->MAC = MAC;
			DiscoveredAddresses.push_back(tmp);
			netdbg("Added %s to discovered addresses", IP.v4.ToStringLittleEndian());
			return tmp;
		}
		case DA_DEL:
		{
			forItr(itr, DiscoveredAddresses)
			{
				if ((*itr)->IP.v4 == IP.v4)
				{
					DiscoveredAddress *tmp = *itr;
					netdbg("Removed %s from discovered addresses", IP.v4.ToStringLittleEndian());
					delete tmp, tmp = nullptr;
					DiscoveredAddresses.erase(itr);
					break;
				}
			}

			return nullptr;
		}
		case DA_SEARCH:
		{
			for (size_t i = 0; i < DiscoveredAddresses.size(); i++)
			{
				if (DiscoveredAddresses[i]->IP.v4 == IP.v4)
				{
					netdbg("Found %s in discovered addresses", IP.v4.ToStringLittleEndian());
					return DiscoveredAddresses[i];
				}
			}
			return nullptr;
		}
		case DA_UPDATE:
		{
			for (size_t i = 0; i < DiscoveredAddresses.size(); i++)
			{
				if (DiscoveredAddresses[i]->IP.v4 == IP.v4)
				{
					DiscoveredAddresses[i]->MAC = MAC;
					netdbg("Updated %s in discovered addresses", IP.v4.ToStringLittleEndian());
					return DiscoveredAddresses[i];
				}
			}
			return nullptr;
		}
		default:
		{
			return nullptr;
		}
		}
		return nullptr;
	}

	ARP::ARP(NetworkEthernet::Ethernet *Ethernet) : NetworkEthernet::EthernetEvents(NetworkEthernet::TYPE_ARP)
	{
		debug("ARP interface %#lx created.", this);
		this->Ethernet = Ethernet;
	}

	ARP::~ARP()
	{
		debug("ARP interface %#lx destroyed.", this);
	}

	MediaAccessControl InvalidMAC;
	InternetProtocol InvalidIP;
	DiscoveredAddress InvalidRet = {.MAC = InvalidMAC, .IP = InvalidIP};

	DiscoveredAddress *ARP::Search(InternetProtocol TargetIP)
	{
		DiscoveredAddress *ret = ManageDiscoveredAddresses(DA_SEARCH, TargetIP, MediaAccessControl());
		if (ret)
			return ret;
		warn("No address found for %s", TargetIP.v4.ToStringLittleEndian());
		return &InvalidRet;
	}

	DiscoveredAddress *ARP::Update(InternetProtocol TargetIP, MediaAccessControl TargetMAC)
	{
		DiscoveredAddress *ret = ManageDiscoveredAddresses(DA_UPDATE, TargetIP, TargetMAC);
		if (ret)
			return ret;
		warn("No address found for %s", TargetIP.v4.ToStringLittleEndian());
		return &InvalidRet;
	}

	uint48_t ARP::Resolve(InternetProtocol IP)
	{
		netdbg("Resolving %s", IP.v4.ToStringLittleEndian());
		if (IP.v4 == 0xFFFFFFFF)
			return 0xFFFFFFFFFFFF;

		/* If we can't find the MAC, return the default value which is 0xFFFFFFFFFFFF. */
		uint48_t ret = this->Search(IP)->MAC.ToHex();
		netdbg("Resolved %s to %lx", IP.v4.ToStringLittleEndian(), ret);

		if (ret == 0xFFFFFFFFFFFF)
		{
			netdbg("Sending request");
			/* If we can't find the MAC, send a request.
			   Because we are going to send this over the network, we need to byteswap first.
			   This is actually very confusing for me. */
			ARPHeader *Header = new ARPHeader;
			Header->HardwareType = b16(ARPHardwareType::HTYPE_ETHERNET);
			Header->ProtocolType = b16(NetworkEthernet::FrameType::TYPE_IPV4);
			Header->HardwareSize = b8(6);
			Header->ProtocolSize = b8(4);
			Header->Operation = b16(ARPOperation::REQUEST);
			Header->SenderMAC = b48(Ethernet->GetInterface()->MAC.ToHex());
			Header->SenderIP = b32(Ethernet->GetInterface()->IP.v4.ToHex());
			Header->TargetMAC = b48(0xFFFFFFFFFFFF);
			Header->TargetIP = b32(IP.v4.ToHex());
			netdbg("SIP: %s; TIP: %s", InternetProtocol().v4.FromHex(Header->SenderIP).ToStringLittleEndian(),
				   InternetProtocol().v4.FromHex(Header->TargetIP).ToStringLittleEndian());
			/* Send the request to the broadcast MAC address. */
			Ethernet->Send({.Address = {0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}}, NetworkEthernet::FrameType::TYPE_ARP, (uint8_t *)Header, sizeof(ARPHeader));
			delete Header, Header = nullptr;
		}

		int RequestTimeout = 20;
		debug("Waiting for response");
		while (ret == 0xFFFFFFFFFFFF)
		{
			ret = this->Search(IP)->MAC.ToHex();
			if (--RequestTimeout == 0)
			{
				warn("Request timeout.");
				return 0;
			}
			netdbg("Still waiting...");
			TaskManager->Sleep(1000);
		}

		return ret;
	}

	void ARP::Broadcast(InternetProtocol IP)
	{
		netdbg("Sending broadcast");
		uint48_t ResolvedMAC = this->Resolve(IP);
		ARPHeader *Header = new ARPHeader;
		Header->HardwareType = b16(ARPHardwareType::HTYPE_ETHERNET);
		Header->ProtocolType = b16(NetworkEthernet::FrameType::TYPE_IPV4);
		Header->HardwareSize = b8(0x6);
		Header->ProtocolSize = b8(0x4);
		Header->Operation = b16(ARPOperation::REQUEST);
		Header->SenderMAC = b48(Ethernet->GetInterface()->MAC.ToHex());
		Header->SenderIP = b32(Ethernet->GetInterface()->IP.v4.ToHex());
		Header->TargetMAC = b48(ResolvedMAC);
		Header->TargetIP = b32(IP.v4.ToHex());
		Ethernet->Send(MediaAccessControl().FromHex(ResolvedMAC), NetworkEthernet::FrameType::TYPE_ARP, (uint8_t *)Header, sizeof(ARPHeader));
		delete Header, Header = nullptr;
	}

	bool ARP::OnEthernetPacketReceived(uint8_t *Data, size_t Length)
	{
		UNUSED(Length);
		netdbg("Received packet");
		ARPHeader *Header = (ARPHeader *)Data;

		InternetProtocol SenderIPv4;
		SenderIPv4.v4 = SenderIPv4.v4.FromHex(b32(Header->SenderIP));

		/* We only support Ethernet and IPv4. */
		if (b16(Header->HardwareType) != ARPHardwareType::HTYPE_ETHERNET || b16(Header->ProtocolType) != NetworkEthernet::FrameType::TYPE_IPV4)
		{
			warn("Invalid hardware/protocol type (%d/%d)", b16(Header->HardwareType), b16(Header->ProtocolType));
			return false;
		}

		InternetProtocol TmpIPStruct;
		InternetProtocol().v4.FromHex(b32(Header->SenderIP));

		if (ManageDiscoveredAddresses(DA_SEARCH, TmpIPStruct, MediaAccessControl().FromHex(b48(Header->SenderMAC))) == nullptr)
		{
			netdbg("Discovered new address %s", SenderIPv4.v4.ToStringLittleEndian());
			TmpIPStruct.v4.FromHex(b32(Header->SenderIP));
			ManageDiscoveredAddresses(DA_ADD, TmpIPStruct, MediaAccessControl().FromHex(b48(Header->SenderMAC)));
		}
		else
		{
			netdbg("Updated address %s", SenderIPv4.v4.ToStringLittleEndian());
			TmpIPStruct.v4.FromHex(b32(Header->SenderIP));
			ManageDiscoveredAddresses(DA_UPDATE, TmpIPStruct, MediaAccessControl().FromHex(b48(Header->SenderMAC)));
		}

		switch (b16(Header->Operation))
		{
		case ARPOperation::REQUEST:
			netdbg("Received request from %s", SenderIPv4.v4.ToStringLittleEndian());
			/* We need to byteswap before we send it back. */
			Header->TargetMAC = b48(Header->SenderMAC);
			Header->TargetIP = b32(Header->SenderIP);
			Header->SenderMAC = b48(Ethernet->GetInterface()->MAC.ToHex());
			Header->SenderIP = b32(Ethernet->GetInterface()->IP.v4.ToHex());
			Header->Operation = b16(ARPOperation::REPLY);
			Ethernet->Send(MediaAccessControl().FromHex(Header->TargetMAC), NetworkEthernet::FrameType::TYPE_ARP, (uint8_t *)Header, sizeof(ARPHeader));
			netdbg("Sent request for %s", SenderIPv4.v4.ToStringLittleEndian());
			break;
		case ARPOperation::REPLY:
			fixme("Received reply from %s", SenderIPv4.v4.ToStringLittleEndian());
			break;
		default:
			warn("Invalid operation (%d)", b16(Header->Operation));
			break;
		}
		return false;
	}
}