#include "MDNS.h"

// Uncomment the define below to enable debug tracing of the mDNS code.
//#define _MDNS_DEBUG_

#ifdef _MDNS_DEBUG_

// sink for debug messages: forwards the formatted text to RLOG
#define DEBUG_PRINT_FN(s) RLOG(s)

// DEBUG_PRINT(fmt, args...) : formats with Format() and logs the result.
// NOTE(review): a trailing 0 is appended to the Format() argument list;
// its purpose is not evident here (possibly to keep the list non-empty).
#define DEBUG_PRINT(msg, ...) do {	\
		String s = Format(msg, ##__VA_ARGS__, 0); \
		DEBUG_PRINT_FN(s); \
	} while(0)

#else

// debugging disabled: DEBUG_PRINT expands to nothing, its arguments are
// not evaluated
#define DEBUG_PRINT(msg, ...)

#endif

// top two bits of a 16 bit word set: DNS name compression pointer
#define MDNS_NAME_REF				0xC000

// DNS resource record types handled here
#define MDNS_TYPE_AAAA				0x001C
#define MDNS_TYPE_A					0x0001
#define MDNS_TYPE_PTR				0x000C
#define MDNS_TYPE_SRV				0x0021
#define MDNS_TYPE_TXT				0x0010

// record class IN, plain and with the mDNS cache-flush bit (0x8000) set
#define MDNS_CLASS_IN				0x0001
#define MDNS_CLASS_IN_FLUSH_CACHE	0x8001

// bits of the 'mask' argument of MDNSClass::Reply() selecting which
// records are sent (ALL = PTR | TXT | SRV | A)
#define MDNS_ANSWERS_ALL			0x0F
#define MDNS_ANSWER_PTR				0x08
#define MDNS_ANSWER_TXT				0x04
#define MDNS_ANSWER_SRV				0x02
#define MDNS_ANSWER_A				0x01

// mDNS IPv4 multicast group and UDP port
static const IPV4Address MDNS_ADDR(224, 0, 0, 251);
// NOTE(review): not referenced in this file
static const int MDNS_MULTICAST_TTL = 1;
static const int MDNS_PORT = 5353;

// default constructor: an empty service description (no host, name or
// protocol, default IP, port and TTL 0) stamped with the current msecs()
MDNSService::MDNSService()
{
	_creationTime	= msecs();
	_port			= 0;
	_ttl			= 0;
	_ip				= IPV4Address();
	_proto.Clear();
	_name.Clear();
	_hostName.Clear();
}

// destructor: all members clean up after themselves
MDNSService::~MDNSService() = default;

// copy constructor: delegates the member-wise (deep) copy to operator=
MDNSService::MDNSService(MDNSService const &s)
{
	*this = s;
}

// pick (move) constructor: delegates to the pick assignment.
// 's' is a named rvalue reference and therefore an lvalue; it must be
// passed on through pick() so that operator=(MDNSService &&) is selected.
// The former 'operator=(s)' chose the copy assignment and deep-copied the
// texts map instead of transferring it.
MDNSService::MDNSService(MDNSService && s)
{
	operator=(pick(s));
}

// copy assignment: member-wise copy, the texts map is deep-copied (<<=)
MDNSService const &MDNSService::operator=(MDNSService const & s)
{
	_texts			<<= s._texts;
	_creationTime	= s._creationTime;
	_ttl			= s._ttl;
	_port			= s._port;
	_proto			= s._proto;
	_name			= s._name;
	_hostName		= s._hostName;
	_ip				= s._ip;
	return *this;
}

// pick assignment: scalar and string members are copied, the texts map is
// transferred from 's' with pick()
MDNSService &MDNSService::operator=(MDNSService && s)
{
	_texts			= pick(s._texts);
	_creationTime	= s._creationTime;
	_ttl			= s._ttl;
	_port			= s._port;
	_proto			= s._proto;
	_name			= s._name;
	_hostName		= s._hostName;
	_ip				= s._ip;
	return *this;
}

// get texts, packed in a string
//
// Serializes the key/value texts in TXT RDATA form: for every entry one
// length byte followed by "key=value". A zero byte (an empty string) is
// always appended at the end, so the result is never empty.
//
// A single length byte can describe at most 255 bytes. Longer entries are
// cut to 255 bytes so that the length byte always matches the data that
// follows it. Previously the length silently wrapped and corrupted the
// record. The length byte is appended with Cat(), like the final zero byte,
// so it is emitted as one raw byte.
String MDNSService::GetTextsPacked(void) const
{
	String res;
	for (int i = 0; i < _texts.GetCount(); i++)
	{
		String s = _texts.GetKey(i) + "=" + _texts[i];
		if (s.GetCount() > 255)
			s = s.Left(255);
		res.Cat((uint8_t)s.GetCount());
		res << s;
	}
	res.Cat((uint8_t)0);
	return res;
}

// add a text: sets 'key' to 'val', replacing an existing value for the key
MDNSService &MDNSService::AddText(String const &key, String const &val)
{
	int const idx = _texts.Find(key);
	if (idx >= 0)
		_texts[idx] = val;
	else
		_texts.Add(key, val);
	return *this;
}

// remove a text: drops 'key' (and its value) from the texts map
MDNSService &MDNSService::RemoveText(String const &key)
{
	MDNSService &self = *this;
	self._texts.RemoveKey(key);
	return self;
}

// get a text value: the value stored for 'key', or an empty string
String MDNSService::GetText(String const &key) const
{
	int const idx = _texts.Find(key);
	if (idx < 0)
		return String();
	return _texts[idx];
}

// debugger helper: "hostname@ip:port"
String MDNSService::ToString(void) const
{
	return _hostName + "@" + _ip.ToString() + ":" + Format("%d", _port);
}


//////////////////////////////////////////////////////////////////////////////////

// get reply ip from source one
//
// Picks, among the local non-loopback IPv4 interfaces, the address that
// shares the longest run of leading octets with 'sourceIP', i.e. the
// interface most likely on the querier's network. If no interface shares
// even the first octet, the first non-loopback interface is returned. A
// default IPV4Address() is returned only when there is no such interface.
//
// Octet [0] is treated as the first (most significant) one. This matches the
// loopback test in AdvertiseServices() (addr[0] == 127). The previous code
// tested ip[3] == 127 and compared octets starting at index 3, so it matched
// address suffixes instead of network prefixes.
IPV4Address MDNSClass::GetReplyIP(IPV4Address const &sourceIP) const
{
	// get all available interfaces IPs
	VectorMap<String, IPV4Address> interfaces = GetLocalIPV4Addresses();

	int bestMatches = 0;
	int bestIdx = -1;
	int firstUsable = -1;
	for (int i = 0; i < interfaces.GetCount(); i++)
	{
		IPV4Address const &ip = interfaces[i];

		// skip loopback interfaces (127.x.x.x)
		if (ip[0] == 127)
			continue;
		if (firstUsable < 0)
			firstUsable = i;

		// count equal octets starting from the most significant one
		int n = 0;
		while (n < 4 && sourceIP[n] == ip[n])
			n++;
		if (n > bestMatches)
		{
			bestMatches = n;
			bestIdx = i;
		}
	}

	// no common prefix at all: still answer with a real local address
	if (bestIdx < 0)
		bestIdx = firstUsable;
	if (bestIdx >= 0)
		return interfaces[bestIdx];
	return IPV4Address();
}

// reply (either to a request or to the multicast address)
//
// Builds one DNS response message holding the records selected by 'mask'
// and sends it over UDP to 'destIP' on the mDNS port:
//   0x08 PTR  "_service._proto.local"          -> "instance._service._proto.local"
//   0x04 TXT  "instance._service._proto.local" -> service.GetTextsPacked()
//   0x02 SRV  "instance._service._proto.local" -> prio 0, weight 0, port, "host.local"
//   0x01 A    "host.local"                     -> ourIP
// Names are written uncompressed. Returns true; nothing is sent if mask is 0.
bool MDNSClass::Reply(uint8_t mask, MDNSService const &service, IPV4Address const &ourIP, IPV4Address const &destIP)
{
	if (mask == 0)
		return true;

	DEBUG_PRINT("TX: mask:%01X, service:%s, proto:%s, port:%u\n", mask, service.GetName(), service.GetProtocol(), service.GetPort());

	String instanceName = _instanceName;			// e.g. "My IOT device"
	String hostName = _hostName;					// e.g. "esp8266"
	String serviceName = "_" + service.GetName();	// e.g. "_http"
	String protoName = "_" + service.GetProtocol();	// e.g. "_tcp"
	String localName = "local";

	StringStream ss;

	// one DNS label: length byte followed by the label bytes
	auto putLabel = [&ss](String const &label)
	{
		ss.Put((uint8_t)label.GetCount());
		ss.Put(~label, label.GetCount());
	};

	// zero-length label terminating a name. It must be written as one byte.
	// The former ss.Put("\0") wrote strlen("\0") == 0 bytes, which left every
	// name unterminated while the RDATA lengths counted the terminator.
	auto putEnd = [&ss]()
	{
		ss.Put((uint8_t)0);
	};

	// "_service._proto.local"
	auto putServiceName = [&]()
	{
		putLabel(serviceName);
		putLabel(protoName);
		putLabel(localName);
		putEnd();
	};

	// "instance._service._proto.local"
	auto putInstanceName = [&]()
	{
		putLabel(instanceName);
		putServiceName();
	};

	// "host.local"
	auto putHostName = [&]()
	{
		putLabel(hostName);
		putLabel(localName);
		putEnd();
	};

	// type, class, TTL and RDATA length, all big endian. The length uses
	// the full 16 bits; it used to be truncated to 8 bits.
	auto putAttrs = [&ss](uint16_t type, uint16_t cls, uint32_t ttl, int rdLen)
	{
		uint8_t attrs[10] =
		{
			(uint8_t)(type >> 8), (uint8_t)(type & 0xFF),
			(uint8_t)(cls >> 8), (uint8_t)(cls & 0xFF),
			(uint8_t)(ttl >> 24), (uint8_t)((ttl >> 16) & 0xFF), (uint8_t)((ttl >> 8) & 0xFF), (uint8_t)(ttl & 0xFF),
			(uint8_t)((rdLen >> 8) & 0xFF), (uint8_t)(rdLen & 0xFF)
		};
		ss.Put(attrs, 10);
	};

	// one answer record per bit set in the low nibble of mask
	uint8_t answerCount = 0;
	for (int bit = 0; bit < 4; bit++)
		if (mask & (1 << bit))
			answerCount++;

	// write the header
	uint8_t head[12] =
	{
		0x00, 0x00,			//ID = 0
		0x84, 0x00,			//Flags = response + authoritative answer
		0x00, 0x00,			//Question count
		0x00, answerCount,	//Answer count
		0x00, 0x00,			//Name server records
		0x00, 0x00,			//Additional records
	};
	ss.Put(head, 12);

	// encoded size of "instance._service._proto.local": 4 length bytes + terminator
	int instanceFqdnLen = instanceName.GetCount() + serviceName.GetCount() + protoName.GetCount() + localName.GetCount() + 5;
	// encoded size of "host.local": 2 length bytes + terminator
	int hostFqdnLen = hostName.GetCount() + localName.GetCount() + 3;

	// PTR response: type 0x000C, class IN, TTL 4500
	if (mask & 0x8)
	{
		putServiceName();
		putAttrs(0x000C, 0x0001, 4500, instanceFqdnLen);
		putInstanceName();
	}

	// TXT response: type 0x0010, class IN, TTL 4500
	if (mask & 0x4)
	{
		putInstanceName();
		String texts = service.GetTextsPacked();
		putAttrs(0x0010, 0x0001, 4500, texts.GetCount());
		ss.Put(~texts, texts.GetCount());
	}

	// SRV response: type 0x0021, class IN with cache flush, TTL 120
	if (mask & 0x2)
	{
		putInstanceName();
		putAttrs(0x0021, 0x8001, 120, 6 + hostFqdnLen);	// 6 = priority, weight, port

		uint16_t port = service.GetPort();
		uint8_t srvRData[6] =
		{
			0x00, 0x00,						// Priority 0
			0x00, 0x00,						// Weight 0
			(uint8_t)((port >> 8) & 0xFF),
			(uint8_t)(port & 0xFF)
		};
		ss.Put(srvRData, 6);
		putHostName();						// target, e.g. "esp8266.local"
	}

	// A response: type 0x0001, class IN with cache flush, TTL 120
	if (mask & 0x1)
	{
		putHostName();
		putAttrs(0x0001, 0x8001, 120, 4);

		// the address bytes are taken from the low byte upwards, as before
		uint32_t ip = ourIP;
		uint8_t ipBytes[4] =
		{
			(uint8_t)(ip & 0xFF),
			(uint8_t)((ip >> 8) & 0xFF),
			(uint8_t)((ip >> 16) & 0xFF),
			(uint8_t)((ip >> 24) & 0xFF)
		};
		ss.Put(ipBytes, 4);
	}

	udp.SendTo(destIP, MDNS_PORT, ss.GetResult());
	return true;
}

// constructor: binds the UDP socket object to the mDNS port; both the host
// name and the advertised instance name default to the machine host name
MDNSClass::MDNSClass() : udp(MDNS_PORT)
{
	String const host = TcpSocket::GetHostName();
	_hostName = host;
	_instanceName = host;
}

// destructor: members release their own resources
MDNSClass::~MDNSClass() = default;

// Look up the discovered service registered under "name:proto" with the
// given IP address. When there is none, append a new entry for that key,
// set its IP and return it.
MDNSService &MDNSClass::FindAdd(String const &name, String const &proto, IPV4Address const &ip)
{
	String const key = name + ":" + proto;
	DEBUG_PRINT("ADDING SERVICE %s\n", key);

	int const count = _discoveredServices.GetCount();
	for (int i = 0; i < count; i++)
		if (_discoveredServices.GetKey(i) == key && _discoveredServices[i].GetIP() == ip)
			return _discoveredServices[i];

	MDNSService &created = _discoveredServices.Add(key);
	created.SetIP(ip);
	return created;
}

// true if "service:protocol" was registered with ListenFor()
bool MDNSClass::IsInteresting(String const &service, String const &protocol) const
{
	int const pos = _interestingServices.Find(service + ":" + protocol);
	return !(pos < 0);
}


// stream reading helpers

// one byte from the stream (Get8() result truncated to 8 bits)
static inline uint8_t read8(StringStream &s)
{
	int const value = s.Get8();
	return static_cast<uint8_t>(value);
}

// big endian (network order) 16 bit value
static inline uint16_t read16MSB(StringStream &s)
{
	uint16_t value = read8(s);
	value = (uint16_t)((value << 8) | read8(s));
	return value;
}

// big endian (network order) 32 bit value
static inline uint32_t read32MSB(StringStream &s)
{
	uint32_t value = read16MSB(s);
	value = (value << 16) | read16MSB(s);
	return value;
}

// raw read of up to 'len' bytes into 'buf'; returns the count actually read
static inline uint16_t read(StringStream &ss, uint8_t *buf, uint16_t len)
{
	int const got = ss.Get(buf, len);
	return static_cast<uint16_t>(got);
}

// read 'len' bytes from the stream and return them as a String.
// Stream::Get(int) returns at most 'len' bytes. On a truncated packet the
// result is therefore shorter, instead of being padded with the
// uninitialized contents of a stack buffer. This also removes the
// non-standard variable length array / alloca.
static inline String readString(StringStream &ss, uint16_t len)
{
	return ss.Get((int)len);
}

// read the 4 address bytes (in packet order) of an A record.
// The buffer is zero-initialized so that a truncated packet yields
// defined bytes instead of uninitialized ones.
static IPV4Address readIP(StringStream &ss)
{
	uint32_t buf = 0;
	ss.Get((uint8_t *)&buf, 4);
	return IPV4Address(buf);
}

// parse answer packet
//
// Reads hdr[3] resource records from 'ss', which is positioned right after
// the 12 byte header. It collects the service name/protocol (PTR), the
// texts (TXT), the port and host name (SRV) and the address (A). When all
// four record types were seen and the service/protocol was registered with
// ListenFor(), the matching discovered-service entry is created or updated
// and WhenService() is called.
//   hdr : the six 16 bit header words (ID, flags, question, answer,
//         authority and additional counts)
// Returns false when the packet carries fewer than 4 answers, true otherwise.
// NOTE(review): only the answer section (hdr[3]) is walked; records in the
// additional section (hdr[5]) are never examined.
bool MDNSClass::ParseAnswer(uint16_t const *hdr, StringStream &ss)
{
	DEBUG_PRINT("Reading answers RX: REQ, ID:%02x, Q:%02x, A:%02x, NS:%02x, ADD:%02x\n", hdr[0], hdr[2], hdr[3], hdr[4], hdr[5]);

	int numAnswers = hdr[3];

	// Assume that the PTR answer always comes first and that it is always accompanied by a TXT, SRV, AAAA (optional) and A answer in the same packet.
	if (numAnswers < 4)
	{
		DEBUG_PRINT("Expected a packet with 4 answers, returning\n");
		return false;
	}

	/*
		uint16_t answerPort = 0;
		uint8_t answerIp[4] = { 0, 0, 0, 0 };
		char answerHostName[255];
	*/
	String fullHostName;
	String answerServiceName;
	String answerProto;
	String answerHostName;
	uint16_t answerPort = 0;
	uint32_t answerTTL = 0;
	IPV4Address answerIP;
	VectorMap<String, String> answerTexts;

	// bit set of record types seen: 0x01 PTR, 0x02 TXT, 0x04 SRV, 0x08 A
	uint8_t partsCollected = 0;

	while (numAnswers--)
	{
		// Read names: the labels of the record's owner name. A compression
		// pointer ends the name; its target is not followed.
		Vector<String> names;
		do
		{
			uint8_t len = read8(ss);
			if (len & 0xC0)   // Compressed pointer (not supported)
			{
				len = read8(ss);
				break;
			}
			if (len == 0x00)   // End of names
				break;
			names.Add(readString(ss, len));
		}
		while (true);

		DEBUG_PRINT("Names : %s", names.ToString());

		// read answer type
		uint16_t answerType = read16MSB(ss);

		// read answer class -- not used here
		/* uint16_t answerClass = */
		read16MSB(ss);

		// read TTL -- use largest one as service TTL
		uint32_t ttl = read32MSB(ss);
		if (ttl > answerTTL)
			answerTTL = ttl;

		// read answer length
		uint16_t answerRdlength = read16MSB(ss);

		DEBUG_PRINT("type: %04x, rdlength:%d\n", answerType, answerRdlength);

		if (answerType == MDNS_TYPE_PTR)
		{
			DEBUG_PRINT("PTR RECORD\n");

			partsCollected |= 0x01;
			// Read rdata (raw, still in label-encoded form)
			fullHostName = readString(ss, answerRdlength);
			DEBUG_PRINT("FullHostName : %s\n", fullHostName);

			// we get service name and protocol here, from the first two
			// labels of the owner name, stripping the leading '_'
			if (names.GetCount() > 1)
			{
				answerServiceName = names[0];
				answerProto = names[1];
			}
			if (answerServiceName[0] == '_')
				answerServiceName = answerServiceName.Mid(1);
			if (answerProto[0] == '_')
				answerProto = answerProto.Mid(1);
			DEBUG_PRINT("Service  : %s\n", answerServiceName);
			DEBUG_PRINT("Protocol : %s\n", answerProto);

		}
		else if (answerType == MDNS_TYPE_TXT)
		{
			DEBUG_PRINT("TXT RECORD\n");

			partsCollected |= 0x02;
			// Read rdata
			String packed = readString(ss, answerRdlength);
			int packedLen = packed.GetCount();

			// read text fields: length-prefixed strings; scanning stops at
			// the first zero length byte or when a length exceeds the
			// remaining data. Entries without '=' are ignored.
			// NOTE(review): the end of data is presumably detected through
			// the String's own terminating zero -- confirm.
			uint8_t const *p = (uint8_t const *)~packed;
			while (*p)
			{
				uint8_t len = *p++;
				packedLen--;
				if (len > packedLen)
					break;
				String s(p, len);
				p += len;
				packedLen -= len;
				int idx = s.Find('=');
				if (idx < 0)
					continue;
				answerTexts.Add(s.Left(idx), s.Mid(idx + 1));
			}

			DEBUG_PRINT("AnswerTexts : %s\n", answerTexts.ToString());
		}
		else if (answerType == MDNS_TYPE_SRV)
		{
			DEBUG_PRINT("SRV RECORD\n");

			partsCollected |= 0x04;
			/* uint16_t answerPrio = */
			read16MSB(ss);
			/* uint16_t answerWeight = */
			read16MSB(ss);
			answerPort = read16MSB(ss);
			DEBUG_PRINT("Poer : %d\n", answerPort);

			// Read hostname: only the first label of the target is kept
			uint8_t len = read8(ss);
			if (len & 0xC0)   // Compressed pointer (not supported)
			{
				// NOTE(review): any rdata after the 2 byte pointer is not skipped
				DEBUG_PRINT("Skipping compressed pointer\n");
				len = read8(ss);
			}
			else
			{
				answerHostName = readString(ss, len);
				DEBUG_PRINT("AnswerHostName : %s\n", answerHostName);
				// 6 bytes prio/weight/port + 1 length byte + the label
				if (answerRdlength - (6 + 1 + len) > 0)
					// Skip any remaining rdata
					ss.Skip(answerRdlength - (6 + 1 + len));
			}
		}

		else if (answerType == MDNS_TYPE_A)
		{
			DEBUG_PRINT("A RECORD\n");

			partsCollected |= 0x08;
			// exactly 4 bytes are consumed, whatever answerRdlength says
			answerIP = readIP(ss);
			DEBUG_PRINT("AnswerIP : %s\n", answerIP.ToString());
		}
		else
		{
			DEBUG_PRINT("Ignoring unsupported record type %d\n", answerType);
			ss.Skip(answerRdlength);
		}

	}
	// when all parts has been collected we must check if the service
	// is one on which we're interested
	if ((partsCollected == 0x0F))
	{
		DEBUG_PRINT("All parts collected, check if we're interested\n");
		DEBUG_PRINT("ServiceName is %s:%s\n", answerServiceName, answerProto);
		DEBUG_PRINT("TTL : %d\n", (int)answerTTL);
		if (!IsInteresting(answerServiceName, answerProto))
			return true;
		DEBUG_PRINT("YEP, WE'RE INTERESTED\n");

		// ok, we're interested, get a pointer to relevant service
		// or create it if still not there
		MDNSService &service = FindAdd(answerServiceName, answerProto, answerIP);

		// populate service
		service.SetPort(answerPort);
		service.SetHostName(answerHostName);
		service.SetTTL(answerTTL);
		service.SetCreationTime(msecs());
		service.GetTexts() = pick(answerTexts);
		
// DUMP(_discoveredServices);

		// call handler
		WhenService();
	}
	else
	{
		DEBUG_PRINT("mhhhh...some part is missing\n");
	}

	return true;
}

// parse query packet
//
// Decodes the question name, which is one of:
//   "<host or instance>._service._proto.local",
//   "_service._proto.local" or
//   "<host>.local".
// A "_services._dns-sd" query triggers AdvertiseServices(). The function
// then reads up to 4 question type/class pairs and answers with the
// matching records of the addressed provided service.
//   fromIP : sender of the query. The reply is sent back to it, and the A
//            record carries the local interface address chosen by
//            GetReplyIP(fromIP).
//   hdr    : the six 16 bit header words; hdr[2] is the question count
// Returns false on malformed/unsupported names or an unknown service, true
// when the query was not for us or was answered.
bool MDNSClass::ParseQuery(IPV4Address const &fromIP, uint16_t const *hdr, StringStream &ss)
{
	String hostName;
	uint8_t hostNameLen;

	bool serviceParsed = false;
	uint8_t serviceNameLen = 0;
	String serviceName;

	bool protoParsed = false;
	uint8_t protoNameLen = 0;
	String protoName;

	bool localParsed = false;
	uint8_t localNameLen = 0;
	String localName;

	// first string may be an host name, or a service name
	// if it's an host name must match our host name or instance name
	hostNameLen = read8(ss);
	hostName = readString(ss, hostNameLen);

	// if not starting with '_' it's an host name
	// and we shall just check if it matches our
	if (hostName[0] != '_')
	{
		if (hostName != _hostName && hostName != _instanceName)
		{
			DEBUG_PRINT("NOT FOR US : %s\n", hostName);
			DEBUG_PRINT("hostname   : %s\n", _hostName);
			DEBUG_PRINT("instance   : %s\n", _instanceName);
			return true;
		}
	}
	else
	{
		// it's a service name, check that its length fits
		hostNameLen--;
		if (hostNameLen > 32)
		{
			DEBUG_PRINT("BAD SERVICE NAME LENGTH %d\n", hostNameLen);
			return false;
		}
		serviceNameLen = hostNameLen;
		serviceName = hostName.Mid(1);
		hostName.Clear();
		hostNameLen = 0;
		serviceParsed = true;
	}

	// if still not parsed service, go on
	if (!serviceParsed)
	{
		serviceNameLen = read8(ss);
		if (serviceNameLen > 33)
		{
			DEBUG_PRINT("BAD SERVICE NAME LENGTH %d\n", serviceNameLen);
			return false;
		}

		// length fits, it may be a service name or 'local'
		// just fetch first byte
		char c = (char)read8(ss);
		if (c == '_')
		{
			// ok, it's a service, fetch it
			serviceNameLen--;
			serviceName = readString(ss, serviceNameLen);
			serviceParsed = true;
		}
		else
		{
			// not a service, it may be 'local'
			if (serviceNameLen != 5)
			{
				DEBUG_PRINT("ERROR, EXPECTING 'local', GOT LEN %d\n", serviceNameLen);
				return false;
			}
			localName.Clear();
			localName.Cat(c);
			localNameLen = serviceNameLen;
			localName << readString(ss, localNameLen - 1);
			if (localName != "local")
			{
				DEBUG_PRINT("ERROR, EXPECTING 'local', GOT %s\n", localName);
				return false;
			}

			// check for null terminator
			uint8_t tmp = read8(ss);
			if (tmp)
			{
				DEBUG_PRINT("EXPECTING NULL TERMINATOR, GOT %d\n", tmp);
				return false;
			}

			// ok, it was .local
			serviceNameLen = 0;
			serviceName.Clear();
			serviceParsed = true;

			protoNameLen = 0;
			protoName.Clear();
			protoParsed = true;

			localParsed = true;
		}
	}

	// if we've still to parse protocol, go on
	if (!protoParsed)
	{
		protoNameLen = read8(ss);
		if (protoNameLen > 32)
		{
			DEBUG_PRINT("BAD PROTOCOL NAME LENGTH %d\n", protoNameLen);
			return false;
		}
		protoName = readString(ss, protoNameLen);

		if (protoNameLen == 4 && protoName[0] == '_')
		{
			protoName = protoName.Mid(1);
			protoNameLen--;
			protoParsed = true;
		}
		else if (serviceName == "services" && protoName == "_dns-sd")
		{
			// we've been queried for all services, do it
			AdvertiseServices();
			return true;
		}
		else
		{
			DEBUG_PRINT("BAD PROTOCOL '%s'\n", protoName);
			return false;
		}
	}

	// local may still be missing, so go on
	if (!localParsed)
	{
		localNameLen = read8(ss);
		if (localNameLen != 5)
		{
			DEBUG_PRINT("EXPECTING 'local', GOT LEN %d\n", localNameLen);
			return false;
		}
		localName = readString(ss, localNameLen);

		if (localName != "local")
		{
			DEBUG_PRINT("EXPECTING 'local', GOT '%s'\n", localName);
			return false;
		}

		// read terminator
		uint8_t tmp = read8(ss);
		if (tmp)
		{
			DEBUG_PRINT("EXPECTING NULL TERMINATOR, GOT %d\n", tmp);
			return false;
		}
		localParsed = true;
	}

	// if we've been queried for a service/protocol, just give it
	int serviceIdx = -1;
	if (serviceNameLen > 0 && protoNameLen > 0)
	{
		serviceIdx = _providedServices.Find(String(serviceName) + ":" + protoName);
		if (serviceIdx < 0)
		{
			DEBUG_PRINT("SERVICE '%s' WITH PROTOCOL '%s' NOT AVAILABLE\n", serviceName, protoName);
			return false;
		}
	}
	else if (serviceNameLen > 0 || protoNameLen > 0)
	{
		DEBUG_PRINT("MISSING SERVICE NAME OR PROTOCOL\n");
		return false;
	}

	// RESPOND
	DEBUG_PRINT("RX: REQ, ID:%02x, Q:%02x, A:%02x , NS:%02x, ADD:%02x\n",
				hdr[0], hdr[2], hdr[3], hdr[4], hdr[5]
			   );

	uint16_t currentType;
	uint16_t currentClass;

	// collect at most 4 question types (class IN only)
	int numQuestions = hdr[2];
	if (numQuestions > 4) numQuestions = 4;
	uint16_t questions[4];
	int question = 0;

	while (numQuestions--)
	{
		currentType = read16MSB(ss);
		if (currentType & MDNS_NAME_REF) //new header handle it better!
			currentType = read16MSB(ss);

		currentClass = read16MSB(ss);
		if (currentClass & MDNS_CLASS_IN)
			questions[question++] = currentType;

		if (numQuestions > 0)
		{
			if (read16MSB(ss) != 0xC00C) //new question but for another host/service
			{
				numQuestions = 0;
			}
		}

	}

	// map question types to Reply() mask bits (0x8 PTR, 0x4 TXT, 0x2 SRV, 0x1 A)
	uint8_t responseMask = 0;
	for (uint8_t i = 0; i < question; i++)
	{
		if (questions[i] == MDNS_TYPE_A)
			responseMask |= 0x1;
		else if (questions[i] == MDNS_TYPE_SRV)
			responseMask |= 0x3;
		else if (questions[i] == MDNS_TYPE_TXT)
			responseMask |= 0x4;
		else if (questions[i] == MDNS_TYPE_PTR)
			responseMask |= 0xF;
	}

	DEBUG_PRINT("ANSWERING SERVICE REQUEST\n");
	// NOTE(review): a query for "<host>.local" alone reaches this point with
	// serviceIdx < 0 and is not answered, although an A record would only
	// need the host name.
	if (serviceIdx < 0)
		return false;

	// Reply(mask, service, ourIP, destIP): the A record carries our interface
	// address closest to the querier, and the packet goes back to the querier.
	// Previously the two addresses were swapped: the querier's IP was put in
	// the A record and the packet was sent to our own address.
	Reply(responseMask, _providedServices[serviceIdx], GetReplyIP(fromIP), fromIP);
	return true;
}


// parse incoming packet
//
// Receives one datagram, decodes the 12 byte DNS header into six big endian
// 16 bit words and dispatches to ParseAnswer() when the QR bit (0x8000 of
// the flags word) is set, or to ParseQuery() otherwise.
// Returns false when nothing is available or the datagram is too short to
// hold a DNS header. Otherwise it returns the dispatched parser's result.
bool MDNSClass::ParsePacket(void)
{
	if (!udp.Available())
		return false;

	IPV4Address ip;
	uint16_t port;
	String data;
	udp.Receive(ip, port, data);
// DUMP(data);

	DEBUG_PRINT("GOT PACKET WITH %d bytes from %s\n", data.GetCount(), ip.ToString());

	// A truncated datagram would produce garbage header words, e.g. an
	// answer count of 0xFFFF that makes the parsers loop over missing data.
	if (data.GetCount() < 12)
		return false;

	StringStream ss(data);

	// read packet header
	uint16_t packetHeader[6];
	for (uint8_t i = 0; i < 6; i++)
		packetHeader[i] = read16MSB(ss);

	// check if is an answer packet
	if ((packetHeader[1] & 0x8000) != 0)
		return ParseAnswer(packetHeader, ss);

	// a query packet
	return ParseQuery(ip, packetHeader, ss);
}

// listen for a service: registers "name:proto" as interesting (no duplicates)
MDNSClass &MDNSClass::ListenFor(String const &name, String const &proto)
{
	String const key = name + ":" + proto;
	_interestingServices.FindAdd(key);
	return *this;
}

// un-listen for a service: removes "name:proto" from the interesting set
MDNSClass &MDNSClass::NoListenFor(String const &name, String const &proto)
{
	String const key = name + ":" + proto;
	_interestingServices.RemoveKey(key);
	return *this;
}

// enable/disable listening for services.
// Enabling starts the socket, joins the mDNS multicast group and sends an
// initial Query(); disabling stops the socket.
MDNSClass &MDNSClass::Listen(bool l)
{
	if (!l)
	{
		udp.Stop();
		return *this;
	}

	udp.Listen();
	udp.IGMPJoin(MDNS_ADDR);
	Query();
	return *this;
}

// get discovered services: copies of all discovered entries for "name:proto"
Array<MDNSService> MDNSClass::GetDiscoveredServices(String const &name, String const &proto)
{
	String const wanted = name + ":" + proto;
	Array<MDNSService> found;

	for (int i = 0; i < _discoveredServices.GetCount(); i++)
		if (_discoveredServices.GetKey(i) == wanted)
			found.Add(_discoveredServices[i]);

	return found;
}

// add or replace a provided service, keyed by "name:protocol"
MDNSClass &MDNSClass::AddService(MDNSService const &s)
{
	_providedServices.GetAdd(s.GetName() + ":" + s.GetProtocol()) = s;
	return *this;
}

// remove a provided service, identified by name and protocol
MDNSClass &MDNSClass::RemoveService(String const &name, String const &proto)
{
	_providedServices.RemoveKey(name + ":" + proto);
	return *this;
}

// advertise provided services.
// For every local IPv4 address except loopback (127.x.x.x), multicasts the
// full record set (PTR, TXT, SRV, A) of each provided service, with that
// address in the A record.
MDNSClass &MDNSClass::AdvertiseServices(void)
{
	VectorMap<String, IPV4Address> addrs = GetLocalIPV4Addresses();
	int const nServices = _providedServices.GetCount();

	for (int a = 0; a < addrs.GetCount(); a++)
	{
		IPV4Address const &local = addrs[a];
		bool const isLoopback = (local[0] == 127);
		if (isLoopback)
			continue;
		for (int s = 0; s < nServices; s++)
			Reply(MDNS_ANSWERS_ALL, _providedServices[s], local, MDNS_ADDR);
	}
	return *this;
}

// run a query for interesting services.
// Clears the discovered-services list first, so everything gets refreshed.
// Then, for every "name:proto" key registered with ListenFor(), it
// multicasts one standard query holding a single PTR question (class IN)
// for "_name._proto.local".
MDNSClass &MDNSClass::Query(void)
{
	_discoveredServices.Clear();

	for (int n = 0; n < _interestingServices.GetCount(); n++)
	{
		String key = _interestingServices[n];
		int colon = key.Find(':');
		if (colon < 0)
			continue;

		String name = "_" + key.Left(colon);
		String proto = "_" + key.Mid(colon + 1);
		DEBUG_PRINT("queryService %s:%s\n", name, proto);

		// header: ID 0, flags 0 (standard query), one question, no records
		uint8_t const questionCount = 1;
		uint8_t head[12] =
		{
			0x00, 0x00,				// ID = 0
			0x00, 0x00,				// Flags = 0 (standard query)
			0x00, questionCount,	// Question count
			0x00, 0x00,				// Answer count
			0x00, 0x00,				// Name server records
			0x00, 0x00				// Additional records
		};
		String rec;
		rec.Cat(head, 12);

		// question name "_service._proto.local" as length-prefixed labels
		String labels[3] = { name, proto, "local" };
		for (int l = 0; l < 3; l++)
		{
			rec.Cat((uint8_t)labels[l].GetCount());
			rec << labels[l];
		}
		rec.Cat(0);					// terminator

		// question type PTR (0x000C), class IN (0x0001)
		uint8_t ptrAttrs[4] = { 0x00, 0x0c, 0x00, 0x01 };
		rec.Cat(ptrAttrs, 4);

		// send the packet to the multicast address/port
		udp.SendTo(MDNS_ADDR, MDNS_PORT, rec);
	}

	return *this;
}

// loop function -- to be called periodically, if not in MT mode
void MDNSClass::Loop(void)
{
	bool someExpired = false;
	
	// check for expired records and remove them
	for (int i = _discoveredServices.GetCount() - 1; i >= 0; i--)
	{
		MDNSService const &s = _discoveredServices[i];
		if ((uint32_t)msecs() > s.GetTTL() * 1000 + s.GetCreationTime())
		{
			someExpired = true;
			_discoveredServices.Remove(i);
		}
	}

	// if not listening, return
	if (!IsListening())
		return;
	
	// if some packet has expired, run a query
	if(someExpired)
		Query();

	// if some packet arrived, parse it
	if (udp.Available())
	{
		DEBUG_PRINT("UDP AVAIL\n");
		ParsePacket();
	}
}


// shared MDNSClass instance, constructed on first call
MDNSClass &__GetMDNS(void)
{
	static MDNSClass theInstance;
	return theInstance;
}
