//
// nono
// Copyright (C) 2024 nono project
// Licensed under nono-license.txt
//

//
// VirtIO ネットワーク
//

#include "virtio_net.h"
#include "virtio_def.h"
#include "config.h"
#include "ethernet.h"
#include "event.h"
#include "macaddr.h"
#include "memorystream.h"
#include "monitor.h"
#include "scheduler.h"

// デバイス構成レイアウト
class VirtIONetConfigWriter
{
 public:
	std::array<u8, 6> mac {};
	le16 status {};
	le16 max_virtqueue_pairs {};
	le16 mtu {};

 public:
	void WriteTo(uint8 *dst) const;
};

// コンストラクタ
VirtIONetDevice::VirtIONetDevice(uint slot_)
	: inherited(OBJ_VIRTIO_NET, slot_)
{
	// 短縮形
	AddAlias("VNet");
	AddAlias("Ethernet");

	device_id = VirtIO::DEVICE_ID_NETWORK;
	vqueues.emplace_back(this, 0, "ReceiveQ1", 8);
	vqueues.emplace_back(this, 1, "TransmitQ1", 16);

	// 割り込み名
	strlcpy(intrname, "VIONet", sizeof(intrname));
	// 完了通知メッセージ
	msgid = MessageID::VIRTIO_NET_DONE;

	// モニタの行数。
	int vqlines = 0;
	for (const auto& q : vqueues) {
		vqlines += 7 + q.num_max;
	}

	monitor = gMonitorManager->Regist(ID_MONITOR_VIRTIO_NET, this);
	monitor->SetCallback(&VirtIONetDevice::MonitorScreen);
	monitor->SetSize(MONITOR_WIDTH, 3 + vqlines);
}

// デストラクタ
VirtIONetDevice::~VirtIONetDevice()
{
	// EthernetDevice と同じ…。
	if ((bool)hostnet) {
		hostnet->ResetRxCallback();
	}
}

// 動的なコンストラクション
bool
VirtIONetDevice::Create()
{
	if (inherited::Create() == false) {
		return false;
	}

	// EthernetDevice とほぼ同じ…。
	try {
		hostnet.reset(new HostNetDevice(this, 0, "VirtIO Network"));
	} catch (...) { }
	if ((bool)hostnet == false) {
		warnx("Failed to initialize HostNetDevice at %s", __method__);
		return false;
	}

	auto func = ToDeviceCallback(&VirtIONetDevice::HostRxCallback);
	hostnet->SetRxCallback(func, 0);

	return true;
}

// 初期化
bool
VirtIONetDevice::Init()
{
	if (inherited::Init() == false) {
		return false;
	}

	// MAC アドレスを取得。
	if (EthernetDevice::GetConfigMacAddr(0, &macaddr, false) == false) {
		// エラーメッセージ表示済み。
		return false;
	}
	// 表示用。
	macaddr_str = macaddr.ToString(':');

	// DEVICE_FEATURES と構成レイアウトを用意。
	VirtIONetConfigWriter cfg;
	SetDeviceFeatures(VIRTIO_NET_F_CSUM);
	SetDeviceFeatures(VIRTIO_NET_F_GUEST_CSUM);
	SetDeviceFeatures(VIRTIO_NET_F_MTU), macaddr.ExportTo(&cfg.mac[0]);
	SetDeviceFeatures(VIRTIO_NET_F_MAC), cfg.mtu = 1500;
	SetDeviceFeatures(VIRTIO_NET_F_MRG_RXBUF);
	cfg.WriteTo(&device_config[0]);

	// ホストからの受信コールバックを登録
	scheduler->ConnectMessage(MessageID::HOSTNET_RX(0), this,
		ToMessageCallback(&VirtIONetDevice::RxMessage));

	auto evman = GetEventManager();
	event = evman->Regist(this,
		ToEventCallback(&VirtIONetDevice::RxEvent),
		"VirtIONet RX");

	return true;
}

void
VirtIONetDevice::MonitorScreen(Monitor *, TextScreen& screen)
{
	int y = 0;

	screen.Clear();

	y = MonitorScreenDev(screen, y);

	screen.Puts(0, y, "MACAddress:");
	screen.Puts(12, y, macaddr_str.c_str());
	y++;

	y++;
	for (const auto& q : vqueues) {
		y = MonitorScreenVirtQueue(screen, y, q);
		y++;
		y = MonitorScreenVirtQDesc(screen, y, q);
		y++;
	}
}

// QUEUE_READY が変化したら呼ばれる。
void
VirtIONetDevice::QueueReadyChanged(VirtQueue *q)
{
	if (q->idx == 0) {
		// 受信開始/停止
		if (q->GetReady() != 0) {
			hostnet->EnableRx(true);
		} else {
			scheduler->StopEvent(event);
			hostnet->EnableRx(false);
		}
	}
}

// ディスクリプタを一つ処理する。
void
VirtIONetDevice::ProcessDesc(VirtIOReq& req)
{
	VirtQueue *q = req.q;
	uint8 flags;
	uint8 gso_type;
	uint16 hdr_len;
	uint16 gso_size;
	uint16 csum_start;
	uint16 csum_offset;
	uint16 num_buffers;

	if (q->idx == 0) {
		putlog(0, "QUEUE_NOTIFY on rx queue (VQ%u)?", q->idx);
		return;
	}

	flags       = ReqReadU8(req);
	gso_type    = ReqReadU8(req);
	hdr_len     = ReqReadLE16(req);
	gso_size    = ReqReadLE16(req);
	csum_start  = ReqReadLE16(req);
	csum_offset = ReqReadLE16(req);
	num_buffers = ReqReadLE16(req);
	putlog(2, "req.hdr $%08x: flags=%02x gso=%02x hdr_len=%04x gso_size=%04x "
		"csum_start=%04x offset=%04x num=%04x",
		req.rbuf[0].addr,
		flags,
		gso_type,
		hdr_len,
		gso_size,
		csum_start,
		csum_offset,
		num_buffers);

	NetPacket tx_packet;
	uint32 len = req.rremain();
	tx_packet.Resize(len);
	uint32 rest = ReqReadRegion(req, tx_packet.Data(), len);
	tx_packet.Resize(len - rest);

	hostnet->Tx(tx_packet);
}

// これは Host スレッドから呼ばれる
void
VirtIONetDevice::HostRxCallback(uint32 dummy)
{
	// スレッドを超えるためにメッセージを投げる
	scheduler->SendMessage(MessageID::HOSTNET_RX(0));
}

// パケット受信通知
void
VirtIONetDevice::RxMessage(MessageID msgid_, uint32 arg)
{
	VirtQueue *q = &vqueues[0];

	if (__predict_false(q->GetReady() == 0)) {
		return;
	}

	// 受信イベントが止まっていれば動かす。
	// 受信イベントがすでに動いていれば、1パケット受信完了後に
	// ホストキューを確認するので、ここでは何もしなくてよい。
	if (event->IsRunning() == false) {
		event->time = 0;
		scheduler->StartEvent(event);
	}
}

// 受信ループイベント。
void
VirtIONetDevice::RxEvent(Event *ev)
{
	VirtQueue *q = &vqueues[0];

	if (__predict_false(q->GetReady() == 0)) {
		return;
	}

	// 空きディスクリプタがなければ、空くまでポーリングで待ち続ける。
	uint16 avail_idx = ReadLE16(q->driver + 2);
	if (q->last_avail_idx == avail_idx) {
		event->time = 500_usec;
		putlog(2, "%s: wait %u usec", __func__,
			(uint)tsec_to_usec(event->time));
		scheduler->RestartEvent(event);
		return;
	}

	// ディスクリプタに空きが出来たので、ホストキューから1パケット取り出す。
	assert(rx_packet.Length() == 0);
	if (hostnet->Rx(&rx_packet) == false) {
		// 取り出せなくなったら、ここでイベントループを終了。
		rx_packet.Clear();
		putlog(2, "Rx no more rx_packet");
		return;
	}

	putlog(2, "%s: %u bytes from host queue", __func__, rx_packet.Length());

	// rx_packet が埋まって、空きディスクリプタがあるので受信処理へ。
	Rx(q);

	rx_packet.Clear();
	event->time = 500_usec;
	scheduler->RestartEvent(event);
}

// 受信。(rx_packet が埋まっていてかつ空きディスクリプタがある状態で呼ぶこと)
void
VirtIONetDevice::Rx(VirtQueue *q)
{
	assert(rx_packet.Length() != 0);
	assert(q->last_avail_idx != ReadLE16(q->driver + 2));

	// XXX どこでやるのがいいか
	// VirtIO は FCS を含まない 1514 バイトが上限と決まっているが、
	// HostNet の下の受信ドライバが FCS を含めているかどうかが分からない。
	// 仕方ないので 1514 バイトを超える時だけ取り除く…。
	if (rx_packet.Length() > 1514) {
		rx_packet.Resize(1514);
	}

	VirtIOReq req;
	req.q = q;
	StartDesc(req);

	// ヘッダを用意。
	virtio_net_hdr hdr;
	memset(&hdr, 0, sizeof(hdr));
	hdr.hdr_len = htole16(sizeof(hdr));
	// num_buffers は実際に使ったセグメント数らしい。
	uint nseg = 0;
	uint32 len = 0;
	for (; nseg < req.wbuf.size(); nseg++) {
		len += req.wbuf[nseg].len;
		if (len >= sizeof(hdr) + rx_packet.Length()) {
			break;
		}
	}
	hdr.num_buffers = htole16(nseg);

	// ヘッダを書き込む。
	ReqWriteRegion(req, (const uint8 *)&hdr, sizeof(hdr));

	// パケット本体を書き込む。
	ReqWriteRegion(req, rx_packet.Data(), rx_packet.Length());

	CommitDesc(req);

	// ここは VM スレッドなのでそのまま割り込みを上げる。
	Done(q);
}

// この宛先アドレスを受信するかどうか。
// これは HostNet スレッドで呼ばれる。
int
VirtIONetDevice::HWAddrFilter(const MacAddr& dstaddr) const
{
	if (dstaddr.IsUnicast()) {
		if (dstaddr != macaddr) {
			return HPF_DROP_UNICAST;
		}
	}
	// マルチキャストをホストデバイス側でフィルタする機構は
	// CTRLQ 内にあるが、NetBSD のドライバは CTRLQ を実装していない。

	return HPF_PASS;
}

const char *
VirtIONetDevice::GetFeatureName(uint feature) const
{
	static std::pair<uint, const char *> names[] = {
		{ VIRTIO_NET_F_CSUM,			"CSUM" },
		{ VIRTIO_NET_F_GUEST_CSUM,		"G_CSUM" },
		{ VIRTIO_NET_F_CTRL_GUEST_OFFLOADS,	"C_G_OFFL" },
		{ VIRTIO_NET_F_MTU,				"MTU" },
		{ VIRTIO_NET_F_MAC,				"MAC" },
		{ VIRTIO_NET_F_GUEST_TSO4,		"G_TSO4" },
		{ VIRTIO_NET_F_GUEST_TSO6,		"G_TSO6" },
		{ VIRTIO_NET_F_GUEST_ECN,		"G_ECN" },
		{ VIRTIO_NET_F_GUEST_UFO,		"G_UFO" },
		{ VIRTIO_NET_F_HOST_TSO4,		"H_TSO4" },
		{ VIRTIO_NET_F_HOST_TSO6,		"H_TSO6" },
		{ VIRTIO_NET_F_HOST_ECN,		"H_ECN" },
		{ VIRTIO_NET_F_HOST_UFO,		"H_UFO" },
		{ VIRTIO_NET_F_MRG_RXBUF,		"MRG_RX" },
		{ VIRTIO_NET_F_STATUS,			"STATUS" },
		{ VIRTIO_NET_F_CTRL_VQ,			"CTL_VQ" },
		{ VIRTIO_NET_F_CTRL_RX,			"CTL_RX" },
		{ VIRTIO_NET_F_CTRL_VLAN,		"CTL_VLAN" },
		{ VIRTIO_NET_F_GUEST_ANNOUNCE,	"G_ANN" },
		{ VIRTIO_NET_F_MQ,				"MQ" },
		{ VIRTIO_NET_F_CTRL_MAC_ADDR,	"CTL_MAC" },
		{ VIRTIO_NET_F_RSC_EXT,			"RSC_EXT" },
		{ VIRTIO_NET_F_STANDBY,			"STANDBY" },
	};

	for (auto& p : names) {
		if (feature == p.first) {
			return p.second;
		}
	}
	return inherited::GetFeatureName(feature);
}

// デバイス構成レイアウトを書き出す。
void
VirtIONetConfigWriter::WriteTo(uint8 *dst) const
{
	MemoryStreamLE mem(dst);

	for (auto c : mac) {
		mem.Write1(c);
	}
	mem.Write2(status);
	mem.Write2(max_virtqueue_pairs);
	mem.Write2(mtu);
}
