#include <arpa/inet.h>
#include <netinet/in.h>
#include <stdio.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <unistd.h>
#include <stdlib.h>
#include <string.h>

#include "types.h"
#include "netrpc.h"

/* Larger of two values. Function-like macro: the selected argument is
   evaluated twice, so never pass expressions with side effects.
   NOTE(review): identifiers starting with '_' + uppercase are reserved for
   the implementation in C; consider renaming if this ever collides. */
#define _MAX(a, b) ((a) > (b) ? (a) : (b))

/*
 * Transmit one datagram on the context's socket.
 *
 * ctxt   - connection context; only ctxt->socket is used.
 * data   - bytes to transmit (not modified).
 * length - number of bytes to transmit.
 *
 * Returns send()'s result narrowed to s32: bytes sent, or -1 on error.
 */
static s32 _send(netrpc_ctxt_t *ctxt, const void *data, u32 length)
{
	const s32 nsent = send(ctxt->socket, data, length, 0);
	return nsent;
}

/*
 * Receive one datagram from the context's socket.
 *
 * ctxt      - connection context; only ctxt->socket is used.
 * buffer    - destination. recv() writes into it, so it is deliberately
 *             non-const; the old 'const void *' forced const to be cast away.
 * maxlength - capacity of buffer in bytes. For datagram sockets, recv()
 *             discards any excess beyond this.
 *
 * Returns recv()'s result narrowed to s32: bytes received, or -1 on error.
 */
static s32 _recv(netrpc_ctxt_t *ctxt, void *buffer, u32 maxlength)
{
	return recv(ctxt->socket, buffer, maxlength, 0);
}

/*
 * Start a new request: clear the fixed-size header, advance the per-context
 * tag, and store the command code and new tag (converted with _ES32,
 * presumably host -> wire byte order).
 *
 * ctxt - context whose tag counter is incremented; the reply is later
 *        matched against this value by _check_result.
 * h    - header to initialise. It must hold at least sizeof(struct rpc_header) bytes.
 * cmd  - RPC command code.
 *
 * Callers set their command-specific fields after this call. Zeroing first
 * means any fields they do not set go out as zeros instead of leftover
 * stack/heap bytes.
 */
static void _fill_header(netrpc_ctxt_t *ctxt, struct rpc_header *h, u32 cmd)
{
	memset(h, 0, sizeof(*h));
	ctxt->tag++;
	h->cmd = _ES32(cmd);
	h->tag = _ES32(ctxt->tag);
}

/*
 * Check that a reply belongs to the request just sent, then return the
 * server's return code.
 *
 * ctxt - context whose current tag identifies the outstanding request.
 * h    - received reply header (fields still in wire byte order).
 * cmd  - command code that was sent.
 *
 * Returns -1, after printing a diagnostic, if the command or tag does not
 * match. Otherwise returns _ES64(h->reply.retcode).
 */
static s64 _check_result(netrpc_ctxt_t *ctxt, struct rpc_header *h, u32 cmd)
{
	const u32 reply_cmd = _ES32(h->cmd);
	const u32 reply_tag = _ES32(h->tag);

	if(reply_cmd != cmd)
	{
		printf("bad command (0x%08x)\n", reply_cmd);
		return -1;
	}

	if(reply_tag != ctxt->tag)
	{
		printf("bad tag (0x%08x)\n", reply_tag);
		return -1;
	}

	return _ES64(h->reply.retcode);
}

/*
 * Create a UDP socket and connect it to the RPC server, so that plain
 * send()/recv() can be used afterwards.
 *
 * ctxt - context to initialise; on success ctxt->socket and ctxt->tag are set.
 * ip   - stored into sin_addr.s_addr unchanged, so presumably it is already
 *        in network byte order (e.g. from inet_addr) - confirm with callers.
 * port - host-order port number; converted with htons().
 *
 * Returns 1 on success and 0 on failure. On failure no socket is left open
 * and ctxt->socket is -1.
 */
s32 netrpc_connect(netrpc_ctxt_t *ctxt, u32 ip, u32 port)
{
	struct sockaddr_in sa;

	ctxt->socket = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
	if(ctxt->socket == -1)
		return 0;

	/* Clear sin_zero and any platform-specific fields. */
	memset(&sa, 0, sizeof(sa));
	sa.sin_family = AF_INET;
	sa.sin_port = htons(port);
	sa.sin_addr.s_addr = ip;

	if(connect(ctxt->socket, (const struct sockaddr *)&sa, sizeof(sa)) == -1)
	{
		/* Don't leak the descriptor on failure. */
		close(ctxt->socket);
		ctxt->socket = -1;
		return 0;
	}

	ctxt->tag = 0;

	return 1;
}

/*
 * Release the socket opened by netrpc_connect.
 * close()'s return value is intentionally discarded.
 */
void netrpc_close(netrpc_ctxt_t *ctxt)
{
	(void)close(ctxt->socket);
}

/*
 * Round-trip an RPC_PING header to check that the server is responding.
 *
 * Returns -1 if fewer than 16 bytes come back. Otherwise returns the result
 * of _check_result (server retcode, or -1 on a command/tag mismatch).
 */
s64 netrpc_ping(netrpc_ctxt_t *ctxt)
{
	struct rpc_header hdr;
	s32 got;

	_fill_header(ctxt, &hdr, RPC_PING);
	_send(ctxt, &hdr, sizeof hdr);

	got = _recv(ctxt, &hdr, sizeof hdr);
	if(got < 16)
	{
		printf("bad reply size (%d)\n", got);
		return -1;
	}

	return _check_result(ctxt, &hdr, RPC_PING);
}

/*
 * Read up to `size` bytes of remote memory at `addr` into `buffer`.
 *
 * ctxt   - connected context.
 * addr   - remote address to read.
 * buffer - destination. At most `size` bytes are written, however large the
 *          reply datagram is.
 * size   - requested byte count.
 * rdsize - optional. On success it receives the number of bytes copied
 *          (<= size).
 *
 * Returns 0 on success and -1 on any failure: size overflow, allocation,
 * send, short reply, or non-zero server result.
 */
s64 netrpc_readmem(netrpc_ctxt_t *ctxt, u64 addr, void *buffer, u32 size, u32 *rdsize)
{
	/* fsize is u32; refuse sizes that would make it wrap. */
	if(size > (u32)0xFFFFFFFF - (u32)sizeof(struct rpc_header))
	{
		printf("bad read size (%u)\n", size);
		return -1;
	}

	u32 fsize = sizeof(struct rpc_header) + size;
	struct rpc_header *h = (struct rpc_header *)malloc(fsize);
	if(h == NULL)
	{
		printf("out of memory (%u)\n", fsize);
		return -1;
	}

	_fill_header(ctxt, h, RPC_READMEM);

	//Memop params.
	h->memop.addr = _ES64(addr);
	h->memop.size = _ES32(size);

	/* If the request never went out, don't block waiting for a reply. */
	if(_send(ctxt, h, sizeof(struct rpc_header)) < 0)
	{
		printf("send failed\n");
		free(h);
		return -1;
	}

	s32 res = _recv(ctxt, h, fsize);
	if(res >= 16)
	{
		if(_check_result(ctxt, h, RPC_READMEM) == 0)
		{
			/* Payload follows the 16-byte reply header. The datagram may be
			   up to fsize bytes, which is more than the caller's buffer holds,
			   so clamp the copy to the requested size. */
			u32 got = (u32)(res - 16);
			if(got > size)
				got = size;
			memcpy(buffer, h->reply.retdata, got);
			free(h);
			if(rdsize != NULL)
				*rdsize = got;
			return 0;
		}
		else
			printf("bad read size (%u)\n", size);
	}
	else
		printf("bad reply size (%d)\n", res);

	free(h);
	return -1;
}

/*
 * Copy `length` bytes of remote memory starting at `addr` into `buffer`,
 * issuing one netrpc_readmem per 1024-byte chunk.
 *
 * Returns the number of bytes requested in total (== length) on success.
 * Otherwise returns the first non-zero netrpc_readmem result.
 *
 * NOTE(review): rdsize is passed as NULL, so a chunk that comes back shorter
 * than requested is not detected; part of `buffer` may then be left unwritten.
 */
s64 netrpc_memcpy_in(netrpc_ctxt_t *ctxt, void *buffer, u64 addr, u32 length)
{
	u8 *dst = (u8 *)buffer;
	u32 done;

	for(done = 0; done < length; )
	{
		u32 left = length - done;
		u32 chunk = (left > 1024) ? 1024 : left;
		s64 rc = netrpc_readmem(ctxt, addr + done, dst + done, chunk, NULL);

		if(rc != 0)
			return rc;
		done += chunk;
	}

	return done;
}

/*
 * Write `size` bytes from `buffer` to remote memory at `addr`.
 *
 * The request is the header followed by the data (copied into
 * h->memop.data) and is sent as a single datagram.
 *
 * Returns the result of _check_result (server retcode, or -1 on a
 * command/tag mismatch). Returns -1 on size overflow, allocation failure,
 * send failure, or a short reply.
 */
s64 netrpc_writemem(netrpc_ctxt_t *ctxt, u64 addr, const void *buffer, u32 size)
{
	/* fsize is u32; refuse sizes that would make it wrap and overflow the
	   allocation in the memcpy below. */
	if(size > (u32)0xFFFFFFFF - (u32)sizeof(struct rpc_header))
	{
		printf("bad write size (%u)\n", size);
		return -1;
	}

	u32 fsize = sizeof(struct rpc_header) + size;
	struct rpc_header *h = (struct rpc_header *)malloc(fsize);
	if(h == NULL)
	{
		printf("out of memory (%u)\n", fsize);
		return -1;
	}

	_fill_header(ctxt, h, RPC_WRITEMEM);

	//Memop params.
	h->memop.addr = _ES64(addr);
	h->memop.size = _ES32(size);

	//Data. (memcpy with a NULL source is undefined even for 0 bytes.)
	if(size != 0)
		memcpy(h->memop.data, buffer, size);

	/* If the request never went out, don't block waiting for a reply. */
	if(_send(ctxt, h, fsize) < 0)
	{
		printf("send failed\n");
		free(h);
		return -1;
	}

	s32 res = _recv(ctxt, h, sizeof(struct rpc_header));
	if(res >= 16)
	{
		s64 res2 = _check_result(ctxt, h, RPC_WRITEMEM);
		if(res2 == -1)
			printf("bad write size (%u)\n", size);
		free(h);
		return res2;
	}

	printf("bad reply size (%d)\n", res);
	free(h);
	return -1;
}

/*
 * Copy `length` bytes from `buffer` to remote memory starting at `addr`,
 * issuing one netrpc_writemem per 1024-byte chunk.
 *
 * Returns 0 once every chunk has been written (including when length is 0).
 * Otherwise returns the first non-zero netrpc_writemem result.
 * Unlike netrpc_memcpy_in, success is reported as 0, not as a byte count.
 */
s64 netrpc_memcpy_out(netrpc_ctxt_t *ctxt, u64 addr, void *buffer, u32 length)
{
	const u8 *src = (const u8 *)buffer;
	u32 done = 0;

	while(done < length)
	{
		u32 left = length - done;
		u32 chunk = (left > 1024) ? 1024 : left;
		s64 rc = netrpc_writemem(ctxt, addr + done, src + done, chunk);

		if(rc != 0)
			return rc;
		done += chunk;
	}

	return 0;
}

/*
 * Ask the server to perform hypervisor call `code` with up to 8 input
 * registers and return up to 7 output registers.
 *
 * ctxt    - connected context.
 * code    - hypervisor call number.
 * arg_in  - numin input values (may be NULL, in which case none are copied).
 * numin   - number of inputs, at most 8.
 * arg_out - receives numout output values (may be NULL). Outputs that are
 *           missing from a short reply are set to 0.
 * numout  - number of outputs, at most 7.
 *
 * Returns -100 for too many arguments. Returns -1 on allocation failure,
 * send failure, or a short reply. Otherwise returns the result of
 * _check_result (server retcode, or -1 on a command/tag mismatch).
 */
s64 netrpc_hvcall(netrpc_ctxt_t *ctxt, u64 code, const u64 arg_in[], u32 numin, u64 arg_out[], u32 numout)
{
	if(numin > 8)
	{
		printf("too many in args (%u)\n", numin);
		return -100;
	}

	if(numout > 7)
	{
		printf("too many out args (%u)\n", numout);
		return -100;
	}

	u32 i;
	u32 fsize = sizeof(struct rpc_header) + PS3_U64_SIZE * _MAX(numin, numout);
	struct rpc_header *h = (struct rpc_header *)malloc(fsize);
	if(h == NULL)
	{
		printf("out of memory (%u)\n", fsize);
		return -1;
	}

	_fill_header(ctxt, h, RPC_HVCALL);

	//Hvcall params.
	h->hvcall.code = _ES64(code);
	h->hvcall.numin = _ES32(numin);
	h->hvcall.numout = _ES32(numout);

	//Regs.
	if(arg_in != NULL)
	{
		for(i = 0; i < numin; i++)
			h->hvcall.regs[i] = _ES64(arg_in[i]);
	}

	/* If the request never went out, don't block waiting for a reply. */
	if(_send(ctxt, h, fsize) < 0)
	{
		printf("send failed\n");
		free(h);
		return -1;
	}

	s32 res = _recv(ctxt, h, fsize);
	if(res >= 16)
	{
		s64 res2 = _check_result(ctxt, h, RPC_HVCALL);
		if(arg_out != NULL)
		{
			/* Output registers follow the 16-byte reply header. Decode only
			   those actually present in the datagram; never read bytes that
			   were not received (uninitialised heap). memcpy avoids the
			   aliasing/alignment issues of casting retdata to u64 *. */
			const u8 *pout = (const u8 *)h->reply.retdata;
			u32 avail = (u32)((u32)(res - 16) / sizeof(u64));
			for(i = 0; i < numout; i++)
			{
				u64 v = 0;
				if(i < avail)
				{
					memcpy(&v, pout + i * sizeof(u64), sizeof(v));
					v = _ES64(v);
				}
				arg_out[i] = v;
			}
		}
		free(h);
		return res2;
	}

	printf("bad reply size (%d)\n", res);
	free(h);
	return -1;
}

/*
 * Send an RPC_ADDMMIO request carrying a start address and a size.
 * Presumably this registers an MMIO range on the server - confirm against
 * the server side.
 *
 * Returns -1 if fewer than 16 bytes come back. Otherwise returns the result
 * of _check_result.
 */
s64 netrpc_addmmio(netrpc_ctxt_t *ctxt, u64 start, u32 size)
{
	struct rpc_header req;
	s32 nrecv;

	_fill_header(ctxt, &req, RPC_ADDMMIO);
	req.addmmio.start = _ES64(start);
	req.addmmio.size = _ES32(size);

	_send(ctxt, &req, sizeof(req));

	nrecv = _recv(ctxt, &req, sizeof(req));
	if(nrecv < 16)
	{
		printf("bad reply size (%d)\n", nrecv);
		return -1;
	}

	return _check_result(ctxt, &req, RPC_ADDMMIO);
}

/*
 * Send an RPC_DELMMIO request for the range beginning at `start`.
 * Presumably the counterpart of netrpc_addmmio - confirm server semantics.
 *
 * Returns -1 if fewer than 16 bytes come back. Otherwise returns the result
 * of _check_result.
 */
s64 netrpc_delmmio(netrpc_ctxt_t *ctxt, u64 start)
{
	struct rpc_header req;
	s32 nrecv;

	_fill_header(ctxt, &req, RPC_DELMMIO);
	req.delmmio.start = _ES64(start);

	_send(ctxt, &req, sizeof(req));

	nrecv = _recv(ctxt, &req, sizeof(req));
	if(nrecv < 16)
	{
		printf("bad reply size (%d)\n", nrecv);
		return -1;
	}

	return _check_result(ctxt, &req, RPC_DELMMIO);
}

/*
 * Send a parameterless RPC_CLRMMIO request.
 *
 * Returns -1 if fewer than 16 bytes come back. Otherwise returns the result
 * of _check_result.
 */
s64 netrpc_clrmmio(netrpc_ctxt_t *ctxt)
{
	struct rpc_header msg;
	s32 len;

	_fill_header(ctxt, &msg, RPC_CLRMMIO);
	_send(ctxt, &msg, sizeof msg);

	len = _recv(ctxt, &msg, sizeof msg);
	if(len < 16)
	{
		printf("bad reply size (%d)\n", len);
		return -1;
	}

	return _check_result(ctxt, &msg, RPC_CLRMMIO);
}

/*
 * Not implemented over netrpc: always returns -1 without contacting the
 * server. The parameters are accepted only to keep the interface.
 */
s64 netrpc_memset(netrpc_ctxt_t *ctxt, u64 addr, u64 value)
{
	(void)ctxt;
	(void)addr;
	(void)value;
	return -1;
}

/*
 * Not implemented over netrpc: always returns -1 without contacting the
 * server. The parameters are accepted only to keep the interface.
 */
s64 netrpc_vector(netrpc_ctxt_t *ctxt, u64 vec0, u64 vec1, u64 copy_dst, u64 copy_src, u32 copy_size)
{
	(void)ctxt;
	(void)vec0;
	(void)vec1;
	(void)copy_dst;
	(void)copy_src;
	(void)copy_size;
	return -1;
}

/*
 * Send an RPC_SYNC request for the `size`-byte region at `addr`. Judging by
 * the name, this makes freshly written code executable before
 * netrpc_call - confirm server semantics.
 *
 * Returns -1 if fewer than 16 bytes come back. Otherwise returns the result
 * of _check_result.
 */
s64 netrpc_sync_before_exec(netrpc_ctxt_t *ctxt, u64 addr, u32 size)
{
	struct rpc_header req;
	s32 nrecv;

	_fill_header(ctxt, &req, RPC_SYNC);
	req.sync.addr = _ES64(addr);
	req.sync.size = _ES32(size);

	_send(ctxt, &req, sizeof(req));

	nrecv = _recv(ctxt, &req, sizeof(req));
	if(nrecv < 16)
	{
		printf("bad reply size (%d)\n", nrecv);
		return -1;
	}

	return _check_result(ctxt, &req, RPC_SYNC);
}

/*
 * Ask the server to call the function at remote address `addr` with eight
 * 64-bit arguments. Each argument is converted with _ES64 before sending.
 *
 * Returns -1 if fewer than 16 bytes come back. Otherwise returns the result
 * of _check_result (the server's retcode field).
 */
s64 netrpc_call(netrpc_ctxt_t *ctxt, u64 addr, u64 arg1, u64 arg2, u64 arg3, u64 arg4, u64 arg5, u64 arg6, u64 arg7, u64 arg8)
{
	const u64 args[8] = { arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 };
	struct rpc_header req;
	s32 nrecv;
	int k;

	_fill_header(ctxt, &req, RPC_CALL);
	req.call.addr = _ES64(addr);

	for(k = 0; k < 8; k++)
		req.call.args[k] = _ES64(args[k]);

	_send(ctxt, &req, sizeof(req));

	nrecv = _recv(ctxt, &req, sizeof(req));
	if(nrecv < 16)
	{
		printf("bad reply size (%d)\n", nrecv);
		return -1;
	}

	return _check_result(ctxt, &req, RPC_CALL);
}

/*
 * Send a parameterless RPC_EIEIO request. The name suggests the server
 * issues a PowerPC eieio barrier - confirm server side.
 *
 * Returns -1 if fewer than 16 bytes come back. Otherwise returns the result
 * of _check_result.
 */
s64 netrpc_eieio(netrpc_ctxt_t *ctxt)
{
	struct rpc_header msg;
	s32 len;

	_fill_header(ctxt, &msg, RPC_EIEIO);
	_send(ctxt, &msg, sizeof msg);

	len = _recv(ctxt, &msg, sizeof msg);
	if(len < 16)
	{
		printf("bad reply size (%d)\n", len);
		return -1;
	}

	return _check_result(ctxt, &msg, RPC_EIEIO);
}
