#include <sys/types.h>
#include <assert.h>
#include <stdint.h>
#include "utf8.h"

/**
 * \brief determines the byte length of a UTF-8 character sequence
 * @param buf pointer to first byte of sequence
 * @returns length of sequence in bytes (1..4), or 0 if the byte cannot start
 *          a UTF-8 sequence
 *
 * Only the lead byte buf[0] is inspected; the following bytes are not checked,
 * so the sequence itself may still be invalid. 0 is returned for continuation
 * bytes (0x80..0xbf) and for bytes that never appear in UTF-8 (0xc0, 0xc1,
 * 0xf5..0xff).
 */
unsigned int
utf8_char_length(const unsigned char *buf)
{
	const unsigned char lead = buf[0];

	/* 0x00..0x7f: plain ASCII */
	if (lead < 0x80)
		return 1;
	/* 0x80..0xbf: continuation bytes; 0xc0/0xc1: would only start overlong forms */
	if (lead < 0xc2)
		return 0;
	/* 0xc2..0xdf: 110xxxxx */
	if (lead < 0xe0)
		return 2;
	/* 0xe0..0xef: 1110xxxx */
	if (lead < 0xf0)
		return 3;
	/* 0xf0..0xf4: 11110xxx, capped so the result cannot exceed U+10FFFF's lead */
	if (lead < 0xf5)
		return 4;
	/* 0xf5..0xff: forbidden */
	return 0;
}

/**
 * \brief check if a given character sequence starts with valid UTF-8 characters
 * @param buf pointer to start of sequence
 * @param len total length of sequence in bytes
 * @returns number of leading bytes of buf that form complete, valid UTF-8
 *          characters; equal to len when the whole buffer is valid
 *
 * Scanning stops at the first byte that cannot start a character, at the
 * first multi-byte character rejected by utf8_char_value(), or at a
 * multi-byte character that would extend past len (a truncated sequence).
 * No byte at or beyond buf + len is passed to the decoder as part of a
 * sequence.
 */
size_t
utf8_valid(const unsigned char *buf, const size_t len)
{
	size_t pos = 0;

	while (pos < len) {
		const size_t seqlen = utf8_char_length(buf + pos);

		if (seqlen == 0)
			return pos;

		if (seqlen == 1) {
			pos++;
			continue;
		}

		/*
		 * Truncated sequence at the end of the buffer. Compare against the
		 * remaining byte count: pos < len here, so len - pos cannot wrap,
		 * whereas len - seqlen wraps around when seqlen > len and would let
		 * the decoder read beyond buf + len.
		 */
		if (seqlen > len - pos)
			return pos;

		if (utf8_char_value(buf + pos) == 0xffffffff)
			return pos;

		pos += seqlen;
	}

	return pos;
}

/**
 * \brief return the unicode character number for a given UTF-8 sequence
 * @param buf pointer to the first byte of a UTF-8 sequence; the caller must
 *            make the whole sequence (utf8_char_length(buf) bytes) readable
 * @returns the unicode character number, or 0xffffffff on error
 *
 * A sequence is rejected (0xffffffff) when:
 *  - its lead byte cannot start a sequence (see utf8_char_length()),
 *  - any following byte is not a continuation byte (10xxxxxx),
 *  - it is an overlong encoding (the value fits in a shorter sequence),
 *  - it encodes a UTF-16 surrogate (U+D800..U+DFFF), or
 *  - it decodes to a value above U+10FFFF (RFC 3629).
 */
uint32_t
utf8_char_value(const unsigned char *buf)
{
	/* smallest code point that legitimately needs a sequence of each length */
	static const uint32_t min_value[5] = { 0, 0, 0x80, 0x800, 0x10000 };
	const unsigned int len = utf8_char_length(buf);
	uint32_t ret;
	unsigned int pos;

	if (len == 0)
		return 0xffffffff;

	if (len == 1)
		return buf[0];

	/* the lead byte carries 7 - len payload bits (5, 4 or 3 for len 2, 3, 4) */
	ret = buf[0] & ((1u << (7 - len)) - 1);

	for (pos = 1; pos < len; pos++) {
		/*
		 * Every trailing byte must be 10xxxxxx. Bailing out at the first
		 * one that is not also keeps a truncated sequence in a
		 * NUL-terminated string from being read past its terminator.
		 */
		if ((buf[pos] & 0xc0) != 0x80)
			return 0xffffffff;
		ret = (ret << 6) | (uint32_t)(buf[pos] & 0x3f);
	}

	/* overlong encoding */
	if (ret < min_value[len])
		return 0xffffffff;

	/* UTF-16 surrogates are not valid scalar values */
	if ((ret >= 0xd800) && (ret <= 0xdfff))
		return 0xffffffff;

	/* beyond the Unicode code space (e.g. 0xf4 0x90 0x80 0x80) */
	if (ret > 0x10ffff)
		return 0xffffffff;

	return ret;
}

