/*!
 * @file tcp.c
 * @brief Definitions for functionality that handles TCP client operations.
 */
#include "precomp.h"
#include "common_metapi.h"
#include "tcp.h"

/*!
 * @brief Writes data from the remote half of the channel to the established connection.
 * @param channel Pointer to the channel to write to.
 * @param request Pointer to the request packet.
 * @param context Pointer to the channel's context (the \c TcpClientContext for this connection).
 * @param buffer Buffer containing the data to write to the channel.
 * @param bufferSize Size of the buffer indicating how many bytes to write.
 * @param bytesWritten Pointer that receives the number of bytes written to the \c channel.
 * @returns Indication of success or failure.
 * @retval ERROR_SUCCESS writing the data completed successfully.
 * @retval ERROR_INVALID_HANDLE \c context was \c NULL.
 * @remarks When \c send reports \c WSAEWOULDBLOCK, this function waits until the socket
 *          is writable and then retries the send. A single \c send may accept fewer than
 *          \c bufferSize bytes; the amount actually sent is reported via \c bytesWritten.
 *          On any other socket error \c bytesWritten receives 0 and the Winsock error
 *          code is returned.
 */
DWORD tcp_channel_client_write(Channel *channel, Packet *request, LPVOID context, LPVOID buffer, DWORD bufferSize, LPDWORD bytesWritten)
{
	DWORD dwResult = ERROR_SUCCESS;
	TcpClientContext * ctx = NULL;
	LONG written = 0;

	do
	{
		dprintf("[TCP] tcp_channel_client_write. channel=0x%08X, buffsize=%d", channel, bufferSize);

		ctx = (TcpClientContext *)context;
		if (!ctx)
		{
			BREAK_WITH_ERROR("[TCP] tcp_channel_client_write. ctx == NULL", ERROR_INVALID_HANDLE);
		}

		// Retry the send until it succeeds or fails with something other than
		// WSAEWOULDBLOCK. This has to be a real loop: a `continue` inside the
		// enclosing do { } while (0) jumps to its (false) condition and exits
		// without ever re-sending the data.
		while (TRUE)
		{
			written = send(ctx->fd, buffer, bufferSize, 0);
			if (written != SOCKET_ERROR)
			{
				dwResult = ERROR_SUCCESS;
				break;
			}

			dwResult = WSAGetLastError();

			if (dwResult == WSAEWOULDBLOCK)
			{
				struct timeval tv = { 0 };
				fd_set set = { 0 };
				// select() returns an int. Holding it in an unsigned type would make
				// SOCKET_ERROR (-1) compare as "> 0" and hide real errors.
				int res = 0;

				dprintf("[TCP] tcp_channel_client_write. send returned WSAEWOULDBLOCK, waiting until we can send again...");

				while (TRUE)
				{
					tv.tv_sec = 0;
					tv.tv_usec = 1000;

					FD_ZERO(&set);
					FD_SET(ctx->fd, &set);

					// The first argument is ignored by Winsock.
					res = select(0, NULL, &set, NULL, &tv);
					if (res > 0)
					{
						dwResult = ERROR_SUCCESS;
						break;
					}
					else if (res == SOCKET_ERROR)
					{
						dwResult = WSAGetLastError();
						break;
					}

					Sleep(100);
				}

				if (dwResult == ERROR_SUCCESS)
				{
					// The socket is writable again: go back and retry the send.
					continue;
				}

				dprintf("[TCP] tcp_channel_client_write. select == SOCKET_ERROR. dwResult=%d", dwResult);
			}

			// Hard failure: report zero bytes written along with the error code.
			written = 0;
			dprintf("[TCP] tcp_channel_client_write. written == SOCKET_ERROR. dwResult=%d", dwResult);
			break;
		}

		if (bytesWritten)
		{
			*bytesWritten = written;
		}

	} while (0);

	dprintf("[TCP] tcp_channel_client_write. finished. dwResult=%d, written=%d", dwResult, written);

	return dwResult;
}

/*!
 * @brief Native close handler for a TCP client channel; releases its client context.
 * @param channel Pointer to the channel that is being closed.
 * @param request Pointer to the request packet (not used here).
 * @param context The channel's native I/O context, a \c TcpClientContext (may be \c NULL).
 * @returns Indication of success or failure.
 * @retval ERROR_SUCCESS This value is always returned, including when \c context is \c NULL.
 */
DWORD tcp_channel_client_close(Channel *channel, Packet *request, LPVOID context)
{
	TcpClientContext *clientCtx = (TcpClientContext *)context;

	dprintf( "[TCP] tcp_channel_client_close. channel=0x%08X, ctx=0x%08X", channel, clientCtx );

	if (clientCtx == NULL)
	{
		// No context is attached to this channel, so there is nothing to release.
		return ERROR_SUCCESS;
	}

	// Detach the channel from the context first: the channel is already being
	// closed, so freeing the context must not try to close it a second time.
	clientCtx->channel = NULL;
	free_tcp_client_context(clientCtx);

	// The context is gone; make sure the channel no longer points at it.
	met_api->channel.set_native_io_context(channel, NULL);

	return ERROR_SUCCESS;
}

/*!
 * @brief Callback for when there is data available on the local side of the TCP client connection.
 * @details Registered with the scheduler against \c ctx->notify in create_tcp_client_channel.
 *          Reads from \c ctx->fd with \c recv and forwards each chunk to \c ctx->channel.
 *          When the connection is closed (recv returns 0) or reset/aborted, the channel's
 *          native I/O context is cleared and \c ctx is freed; \c ctx must not be used by the
 *          caller after that happens.
 * @param remote Pointer to the remote instance. Not referenced by this function; received
 *               data is forwarded to \c ctx->remote instead.
 * @param ctx Pointer to the TCP client context.
 * @returns Indication of success or failure.
 * @retval ERROR_SUCCESS This value is always returned.
 */
DWORD tcp_channel_client_local_notify(Remote * remote, TcpClientContext * ctx)
{
	struct timeval tv = { 0 };
	fd_set set = { 0 };
	// Receive buffer: at most sizeof(buf) bytes are forwarded per recv() call.
	UCHAR  buf[16384] = { 0 };
	LONG   dwBytesRead = 0;

	// We select in a loop with a zero second timeout because it's possible
	// that we could get a recv notification and a close notification at once,
	// so we need some way to make sure that we see them both; otherwise the
	// event handle won't get re-set to notify us.
	do
	{
		// Reset the notification event before draining the socket.
		ResetEvent(ctx->notify);

		// Prepare the read set used by the select() in the loop condition.
		FD_ZERO(&set);
		FD_SET(ctx->fd, &set);

		// Zero timeout: the select() in the loop condition only polls, it never waits.
		tv.tv_sec = 0;
		tv.tv_usec = 0;

		// Read data from the client connection
		dwBytesRead = recv(ctx->fd, buf, sizeof(buf), 0);

		if (dwBytesRead == SOCKET_ERROR)
		{
			DWORD dwError = WSAGetLastError();

			// WSAECONNRESET: The connection was forcibly closed by the remote host.
			// WSAECONNABORTED: The connection was terminated due to a time-out or other failure.
			if (dwError == WSAECONNRESET || dwError == WSAECONNABORTED)
			{
				dprintf("[TCP] tcp_channel_client_local_notify. [error] closing down channel gracefully. WSAGetLastError=%d", dwError);
				// By setting bytesRead to zero, we can ensure we close down the channel gracefully...
				dwBytesRead = 0;
			}
			else if (dwError == WSAEWOULDBLOCK)
			{
				dprintf("[TCP] tcp_channel_client_local_notify. channel=0x%08X. recv generated a WSAEWOULDBLOCK", ctx->channel);
				// Nothing to read right now: break and let the scheduler notify us again if needed.
				break;
			}
			else
			{
				dprintf("[TCP] tcp_channel_client_local_notify. [error] channel=0x%08X read=0x%.8x (ignored). WSAGetLastError=%d", ctx->channel, dwBytesRead, dwError);
				// dwBytesRead stays -1, so neither branch below runs; we only loop
				// again if the select() in the loop condition reports the socket readable.
				// NOTE(review): if recv keeps failing with the same error while select()
				// still reports readable, this loop would not terminate — confirm.
			}
		}

		if (dwBytesRead == 0)
		{
			dprintf("[TCP] tcp_channel_client_local_notify. [closed] channel=0x%08X read=0x%.8x", ctx->channel, dwBytesRead);

			// Set the native channel operations context to NULL so the channel's
			// close handler no longer sees this context.
			met_api->channel.set_native_io_context(ctx->channel, NULL);

			// Sleep for a quarter second
			Sleep(250);

			// Free the context (ctx is invalid from here on)
			free_tcp_client_context(ctx);

			// Stop processing. `break` leaves the loop without evaluating the
			// select() condition, so nothing below touches the freed context.
			break;
		}
		else if (dwBytesRead > 0)
		{
			if (ctx->channel)
			{
				// Forward the received bytes to the remote side of the channel.
				dprintf("[TCP] tcp_channel_client_local_notify. [data] channel=0x%08X read=%d", ctx->channel, dwBytesRead);
				met_api->channel.write(ctx->channel, ctx->remote, NULL, 0, buf, dwBytesRead, 0);
			}
			else
			{
				// No channel attached any more: the data is dropped.
				dprintf("[TCP] tcp_channel_client_local_notify. [data] channel=<invalid> read=0x%.8x", dwBytesRead);
			}
		}

	// Poll (zero timeout) for more pending data and keep draining while the socket is readable.
	} while (select(1, &set, NULL, NULL, &tv) > 0);

	return ERROR_SUCCESS;
}

/*!
 * @brief Allocates a streaming TCP channel.
 * @param remote Pointer to the remote instance.
 * @param packet Pointer to the request packet.
 * @returns Indication of success or failure.
 * @retval ERROR_SUCCESS This value is always returned; the outcome of the connection
 *         attempt is reported to the remote as the result of the response packet.
 * @remarks The request packet needs to contain:
 *            - \c TLV_TYPE_PEER_HOST - Host to connect to.
 *            - \c TLV_TYPE_PEER_PORT - Port to connect to (only the low 16 bits are used).
 *          A request without \c TLV_TYPE_PEER_HOST fails with \c ERROR_INVALID_PARAMETER.
 *          On success the response carries \c TLV_TYPE_CHANNEL_ID and the local address
 *          information added by \c net_tlv_pack_local_addrinfo.
 */
DWORD request_net_tcp_client_channel_open(Remote *remote, Packet *packet)
{
	Channel *channel = NULL;
	TcpClientContext *ctx = NULL;
	Packet *response = met_api->packet.create_response(packet);
	DWORD result = ERROR_SUCCESS;
	LPCSTR host = NULL;
	DWORD port = 0;

	do
	{
		// No response packet?
		if (!response)
		{
			break;
		}

		// Extract the hostname and port that we are to connect to
		host = met_api->packet.get_tlv_value_string(packet, TLV_TYPE_PEER_HOST);
		port = met_api->packet.get_tlv_value_uint(packet, TLV_TYPE_PEER_PORT);

		// A missing host would otherwise reach inet_addr()/gethostbyname() as NULL,
		// and gethostbyname(NULL) resolves the *local* machine.
		if (!host)
		{
			dprintf("[TCP] request_net_tcp_client_channel_open. TLV_TYPE_PEER_HOST missing");
			result = ERROR_INVALID_PARAMETER;
			break;
		}

		// Open the TCP channel
		if ((result = create_tcp_client_channel(remote, host, (USHORT)(port & 0xffff), &channel, &ctx)) != ERROR_SUCCESS)
		{
			break;
		}

		// Set the channel's identifier on the response
		met_api->packet.add_tlv_uint(response, TLV_TYPE_CHANNEL_ID, met_api->channel.get_id(channel));
		net_tlv_pack_local_addrinfo(ctx, response);

	} while (0);

	// Transmit the response
	met_api->packet.transmit_response(result, remote, response);

	return ERROR_SUCCESS;
}

/*!
 * @brief Creates a connection to a remote host and builds a logical channel to represent it.
 * @param remote Pointer to the remote instance.
 * @param remoteHost The remote host to connect to: a dotted-quad IPv4 address, or a
 *                   host name that is resolved locally. Must not be \c NULL.
 * @param remotePort The remote port to connect to.
 * @param outChannel Optional pointer that receives the newly created channel (\c NULL on failure).
 * @param outContext Optional pointer that receives the newly created tcp client context (\c NULL on failure).
 * @returns Indication of success or failure.
 * @retval ERROR_SUCCESS Creation of the TCP client was successful.
 * @retval ERROR_INVALID_PARAMETER \c remoteHost was \c NULL.
 * @retval ERROR_NOT_ENOUGH_MEMORY The context or the channel could not be allocated.
 * @remarks Other failures return the Winsock error from socket creation, name
 *          resolution or connect. On failure, everything allocated here is released.
 */
DWORD create_tcp_client_channel(Remote *remote, LPCSTR remoteHost, USHORT remotePort, Channel **outChannel, TcpClientContext **outContext)
{
	StreamChannelOps chops;
	TcpClientContext *ctx = NULL;
	DWORD result = ERROR_SUCCESS;
	Channel *channel = NULL;
	struct sockaddr_in s;
	SOCKET clientFd = 0;

	if (outChannel)
	{
		*outChannel = NULL;
	}
	if (outContext)
	{
		*outContext = NULL;
	}

	// inet_addr()/gethostbyname() must not see a NULL host: gethostbyname(NULL)
	// resolves the local machine, which is never what a caller asked for.
	if (!remoteHost)
	{
		dprintf("[TCP] create_tcp_client_channel. remoteHost == NULL");
		return ERROR_INVALID_PARAMETER;
	}

	dprintf("[TCP] create_tcp_client_channel. host=%s, port=%d", remoteHost, remotePort);

	do
	{
		// Allocate a client socket
		if ((clientFd = WSASocket(AF_INET, SOCK_STREAM, 0, NULL, 0, 0)) == INVALID_SOCKET)
		{
			clientFd = 0;
			result = WSAGetLastError();
			break;
		}

		// Zero the whole address so sin_zero is not left uninitialised.
		memset(&s, 0, sizeof(s));
		s.sin_family = AF_INET;
		s.sin_port = htons(remotePort);
		s.sin_addr.s_addr = inet_addr(remoteHost);

		// Not a dotted-quad address: resolve the host name locally
		if (s.sin_addr.s_addr == (DWORD)-1)
		{
			struct hostent *h;

			if (!(h = gethostbyname(remoteHost)))
			{
				result = WSAGetLastError();
				// Never leave the loop with ERROR_SUCCESS here, or the caller would
				// be told the channel was created while outChannel is NULL.
				if (result == ERROR_SUCCESS)
				{
					result = WSAHOST_NOT_FOUND;
				}
				break;
			}

			// Only accept an IPv4 address that fits exactly into sin_addr.
			if (h->h_addrtype != AF_INET || (size_t)h->h_length != sizeof(s.sin_addr.s_addr) || !h->h_addr)
			{
				result = WSANO_DATA;
				break;
			}

			memcpy(&s.sin_addr.s_addr, h->h_addr, sizeof(s.sin_addr.s_addr));
		}

		dprintf("[TCP] create_tcp_client_channel. host=%s, port=%d connecting...", remoteHost, remotePort);
		// Try to connect to the host/port
		if (connect(clientFd, (struct sockaddr *)&s, sizeof(s)) == SOCKET_ERROR)
		{
			result = WSAGetLastError();
			dprintf("[TCP] create client failed host=%s, port=%d error=%u 0x%x", remoteHost, remotePort, result, result);
			break;
		}

		dprintf("[TCP] create_tcp_client_channel. host=%s, port=%d connected!", remoteHost, remotePort);
		// Allocate the client context for tracking the connection
		if (!(ctx = (TcpClientContext *)malloc(sizeof(TcpClientContext))))
		{
			result = ERROR_NOT_ENOUGH_MEMORY;
			break;
		}

		// Initialize the context attributes. From here on the context owns clientFd.
		memset(ctx, 0, sizeof(TcpClientContext));

		ctx->remote = remote;
		ctx->fd = clientFd;

		// Initialize the channel operations structure
		memset(&chops, 0, sizeof(chops));

		chops.native.context = ctx;
		chops.native.write = tcp_channel_client_write;
		chops.native.close = tcp_channel_client_close;

		dprintf("[TCP] create_tcp_client_channel. host=%s, port=%d creating the channel", remoteHost, remotePort);
		// Allocate an uninitialized channel for associated with this connection
		if (!(channel = met_api->channel.create_stream(0, 0, &chops)))
		{
			result = ERROR_NOT_ENOUGH_MEMORY;
			break;
		}

		// Save the channel context association
		ctx->channel = channel;

		// Finally, create a waitable event and insert it into the scheduler's
		// waitable list
		dprintf("[TCP] create_tcp_client_channel. host=%s, port=%d creating the notify", remoteHost, remotePort);
		// NOTE(review): if WSACreateEvent() fails (or WSAEventSelect() fails, whose
		// return value is ignored) the channel is still reported as created but no
		// read notifications will ever be delivered — confirm this is acceptable.
		if ((ctx->notify = WSACreateEvent()))
		{
			WSAEventSelect(ctx->fd, ctx->notify, FD_READ | FD_CLOSE);
			dprintf("[TCP] create_tcp_client_channel. host=%s, port=%d created the notify %.8x", remoteHost, remotePort, ctx->notify);

			met_api->scheduler.insert_waitable(ctx->notify, ctx, NULL, (WaitableNotifyRoutine)tcp_channel_client_local_notify, NULL);
		}

	} while (0);

	dprintf("[TCP] create_tcp_client_channel. host=%s, port=%d all done", remoteHost, remotePort);

	// Clean up on failure
	if (result != ERROR_SUCCESS)
	{
		dprintf("[TCP] create_tcp_client_channel. host=%s, port=%d cleaning up failed connection", remoteHost, remotePort);
		if (ctx)
		{
			// Freeing the context closes ctx->fd, which is clientFd. Forget our
			// copy so the same socket handle is not closed a second time below.
			free_tcp_client_context(ctx);
			ctx = NULL;
			clientFd = 0;
		}

		if (clientFd)
		{
			closesocket(clientFd);
		}

		channel = NULL;
	}

	if (outChannel)
	{
		*outChannel = channel;
	}
	if (outContext)
	{
		*outContext = ctx;
	}

	return result;
}

/*!
 * @brief Deallocates and cleans up the attributes of a socket context.
 * @param ctx Pointer to the socket context to free. A \c NULL pointer is ignored.
 * @remarks Closes the socket, asks the channel API to close \c ctx->channel if one is
 *          still attached, signals the scheduler to stop the waitable registered on
 *          \c ctx->notify, and finally frees \c ctx itself. Callers that must not close
 *          the channel (because it is already being closed) clear \c ctx->channel first.
 */
VOID free_socket_context(SocketContext *ctx)
{
	dprintf("[TCP] free_socket_context. ctx=0x%08X", ctx);

	if (!ctx)
	{
		return;
	}

	// Close the socket. The notification handle is not closed here (see below).
	if (ctx->fd)
	{
		closesocket(ctx->fd);
		ctx->fd = 0;
	}

	// Ask for the channel to be closed if it is still attached to this context.
	if (ctx->channel)
	{
		met_api->channel.close(ctx->channel, ctx->remote, NULL, 0, NULL);
		ctx->channel = NULL;
	}

	if (ctx->notify)
	{
		dprintf("[TCP] free_socket_context. remove_waitable ctx=0x%08X notify=0x%08X", ctx, ctx->notify);
		// The scheduler calls CloseHandle on our WSACreateEvent() for us
		met_api->scheduler.signal_waitable(ctx->notify, SchedulerStop);
		ctx->notify = NULL;
	}

	// Free the context
	free(ctx);
}

/*!
 * @brief Shuts the socket down for either reading or writing.
 * @param remote Pointer to the remote instance.
 * @param packet Pointer to the packet.
 * @remark The contents of the \c packet indicate whether to stop reading or writing:
 *         \c TLV_TYPE_CHANNEL_ID selects the channel and \c TLV_TYPE_SHUTDOWN_HOW is
 *         passed unchanged as the \c how argument of \c shutdown().
 * @returns Indication of success or failure.
 * @retval ERROR_SUCCESS This value is always returned; the outcome of the shutdown is
 *         reported to the remote in the response packet.
 */
DWORD request_net_socket_tcp_shutdown(Remote *remote, Packet *packet)
{
	DWORD dwResult = ERROR_SUCCESS;
	Packet * response = NULL;
	SocketContext * ctx = NULL;
	Channel * channel = NULL;
	DWORD cid = 0;
	DWORD how = 0;

	do
	{
		dprintf("[TCP] entering request_net_socket_tcp_shutdown");
		response = met_api->packet.create_response(packet);
		if (!response)
		{
			BREAK_WITH_ERROR("[TCP] request_net_socket_tcp_shutdown. response == NULL", ERROR_NOT_ENOUGH_MEMORY);
		}

		cid = met_api->packet.get_tlv_value_uint(packet, TLV_TYPE_CHANNEL_ID);
		how = met_api->packet.get_tlv_value_uint(packet, TLV_TYPE_SHUTDOWN_HOW);

		// An unknown channel id must be rejected here; otherwise a NULL channel
		// would be handed to get_native_io_context().
		channel = met_api->channel.find_by_id(cid);
		if (!channel)
		{
			BREAK_WITH_ERROR("[TCP] request_net_socket_tcp_shutdown. channel == NULL", ERROR_INVALID_HANDLE);
		}

		dprintf("[TCP] request_net_socket_tcp_shutdown. channel=0x%08X, cid=%d", channel, cid);

		ctx = met_api->channel.get_native_io_context(channel);
		if (!ctx)
		{
			BREAK_WITH_ERROR("[TCP] request_net_socket_tcp_shutdown. ctx == NULL", ERROR_INVALID_HANDLE);
		}

		if (shutdown(ctx->fd, how) == SOCKET_ERROR)
		{
			BREAK_ON_WSAERROR("[TCP] request_net_socket_tcp_shutdown. shutdown failed");
		}

		// sf: the context is deliberately NOT freed here. The channel's
		// tcp_channel_client_local_notify() catches the socket closure (thanks to the
		// FD_READ|FD_CLOSE flags passed to WSAEventSelect in create_tcp_client_channel())
		// and frees the context itself. Freeing it here as well could cause a double
		// free from two different threads and a subsequent access violation.

	} while (0);

	met_api->packet.transmit_response(dwResult, remote, response);

	dprintf("[TCP] leaving request_net_socket_tcp_shutdown");

	return ERROR_SUCCESS;
}
