#!/usr/bin/python3

#	cve-manager : CVE management tool
#	Copyright (C) 2017-2024 Alexey Appolonov
#
#	This program is free software: you can redistribute it and/or modify
#	it under the terms of the GNU General Public License as published by
#	the Free Software Foundation, either version 3 of the License, or
#	(at your option) any later version.
#
#	This program is distributed in the hope that it will be useful,
#	but WITHOUT ANY WARRANTY; without even the implied warranty of
#	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#	GNU General Public License for more details.
#
#	You should have received a copy of the GNU General Public License
#	along with this program.  If not, see <http://www.gnu.org/licenses/>.

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

from collections          import defaultdict
from re                   import search as re_search
from cve_manager.db       import DB, Placeholders
from cve_manager.defines  import NVD_DATA_SRC, FSTEC_DATA_SRC
from cve_manager.url      import ParseURL, JoinURL, URL_NETLOC, URL_PATH, \
	URL_QUERY, URL_NETLOC_HASH, URL_PATH_HASH, URL_QUERY_HASH
from cpe_map.defines      import M_URL
from cpe_map.multi_branch import Run

# Extra per-URL fields appended after the last of the URL_* indices that
# cve_manager.url provides for a parsed URL record
UNIQUE_NETLOC = max(URL_NETLOC, URL_PATH, URL_QUERY, URL_NETLOC_HASH,
	URL_PATH_HASH, URL_QUERY_HASH) + 1	# netloc uniquely identifies its name
UNIQUE_PATH   = UNIQUE_NETLOC + 1	# path uniquely identifies its name
FROM_CPE_DICT = UNIQUE_PATH + 1		# URL originates from the CPE dictionary
GOT_SUBDOMAIN = FROM_CPE_DICT + 1	# netloc has >= 2 dots, i.e. a subdomain

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Get URL-addresses of packages with given names
# {package: (netloc_of_url, path_of_url)}
# or get URL-addresses of all the products
# {product: [(netloc_of_url1, path_of_url1), (netloc_of_url2, path_of_url2), ..]}

def QueryURLs(mysql_config, branch=None, data_source=None, selection=()):
	"""Query URL-addresses of packages or of products from the database.

	Exactly one of 'branch' and 'data_source' must be given:
	* branch: return {package: parsed_url} for source packages of that
	  branch, optionally restricted to the names in 'selection';
	* data_source (NVD_DATA_SRC or FSTEC_DATA_SRC): return
	  {(vendor, product): [parsed_url1, parsed_url2, ..]} for all products.

	Each returned parsed URL additionally gets its UNIQUE_NETLOC,
	UNIQUE_PATH, GOT_SUBDOMAIN (and, for products, FROM_CPE_DICT) fields
	filled in.  Returns None on invalid arguments or any database failure.
	"""

	# Exactly one of the two mutually exclusive selectors must be set
	if all([branch, data_source]) or not any([branch, data_source]):
		return None

	query_urls_of_packages = bool(branch)

	db = DB()
	if not db.Connect(mysql_config):
		return None

	if query_urls_of_packages:
		table = f'{branch}_src'
		query = f"SELECT name, url, relatives FROM {table} WHERE url <> '' "
		if selection:
			query += f'AND name in ({Placeholders(len(selection))})'
		tables = (table,)
		data = [tuple(selection)]
	elif data_source == NVD_DATA_SRC:
		# Combine references from the CPE dictionary with references from
		# the NVD products table, preferring the dictionary ones when present
		table_a = 'cpe_import'
		table_b = f'{NVD_DATA_SRC}_products'
		query = ("("
			f"SELECT vendor, product, refs, '' FROM {table_a} a "
			f"WHERE refs <> '' AND EXISTS (SELECT 1 FROM {table_b} b "
				"WHERE a.product = b.product)"
			") UNION ("
			f"SELECT DISTINCT b.vendor, b.product, '', b.refs FROM {table_b} b "
			f"INNER JOIN {table_a} a "
				"ON a.vendor = b.vendor AND a.product = b.product "
			"WHERE a.refs = '' AND b.refs <> ''"
			") UNION ("
			f"SELECT vendor, product, '', refs FROM {table_b} x "
			f"WHERE NOT EXISTS (SELECT 1 FROM {table_a} y "
				"WHERE x.vendor = y.vendor AND x.product = y.product) "
				"AND refs <> ''"
			")")
		tables = (table_a, table_b)
		data = []
	elif data_source == FSTEC_DATA_SRC:
		table = f'{FSTEC_DATA_SRC}_products'
		query = f"SELECT vendor, product, refs, '' FROM {table} " \
			f"WHERE refs <> ''"
		tables = (table,)
		data = []
	else:
		return None
	cursor = db.Cursor(query, tables=tables, data=data)
	if not cursor:
		return None

	# {netloc: {path: {names seen with this netloc+path}}}
	names_of_paths_of_netlocs = defaultdict(lambda: defaultdict(set))
	# {netloc or its parent domain: {names seen with it}}
	names_of_netlocs = defaultdict(set)

	# Write given name and a netloc and a path of a given URL into helper
	# data structures
	def UpdateSupplementaryData(name, url):
		netloc = url[URL_NETLOC]
		path   = url[URL_PATH]
		names_of_paths_of_netlocs[netloc][path].add(name)
		names_of_netlocs[netloc].add(name)
		# Also register the parent domain ("sub.example.com" -> "example.com")
		# so URLs that differ only in a subdomain are seen as sharing a netloc
		subdomain_end_pos = netloc.rfind('.', 0, netloc.rfind('.'))
		if subdomain_end_pos > 0:
			domain = netloc[subdomain_end_pos + 1:]
			names_of_netlocs[domain].add(name)
	# - nested func

	# Check if netloc and path of a given URL are unique
	def UniqueNetlocAndPath(url, relatives=None):
		if relatives is None:	# avoid a shared mutable default argument
			relatives = {}
		unique_netloc = \
			BasicUniquenessCheck(names_of_netlocs[url[URL_NETLOC]], relatives)
		# This netloc can have multiple paths for multiple names, and
		# definitely has path(s) for a name that is currently processed
		names_of_paths = names_of_paths_of_netlocs[url[URL_NETLOC]]
		# Path is not unique if various names have it
		this_path = url[URL_PATH]
		if not BasicUniquenessCheck(names_of_paths[this_path], relatives):
			return unique_netloc, False
		# Path is unique if it's not included in other paths of this netloc;
		# NB: '- {this_path}' excludes this very path — the previous
		# 'set(this_path)' built a set of its characters instead
		other_paths = set(names_of_paths.keys()) - {this_path}
		unique_path = all(len(other_path) <= len(this_path) or
			not other_path.startswith(this_path) for other_path in other_paths)
		return unique_netloc, unique_path
	# - nested func

	# Domain name or path can't be unique if there are others exactly like it
	# or if it's shared among names that don't all belong to a single set
	# of relatives
	def BasicUniquenessCheck(names_of_this_netloc_or_path, relatives):
		if len(names_of_this_netloc_or_path) == 1:
			return True
		related_names = set()
		for name in names_of_this_netloc_or_path:
			related_names |= relatives.get(name, set())
		return not (names_of_this_netloc_or_path - related_names)
	# - nested func

	if query_urls_of_packages:
		res = {}
		relatives = defaultdict(set)
		for name, url, rel in cursor:
			parsed_url = ParseURL(url)
			if not parsed_url:
				continue
			res[name]  = parsed_url
			relatives[name] = set(rel.split())
			UpdateSupplementaryData(name, parsed_url)
		# NOTE(review): parsed URLs are assumed to be mutable sequences with
		# slots up to GOT_SUBDOMAIN — confirm against ParseURL
		for name, url in res.items():
			res[name][UNIQUE_NETLOC], res[name][UNIQUE_PATH] = \
				UniqueNetlocAndPath(url, relatives)
			res[name][GOT_SUBDOMAIN] = url[URL_NETLOC].count('.') >= 2
	else:
		pre_res = defaultdict(dict)
		for vendor, product, refs_main, refs_plus in cursor:
			for refs, from_cpe_dict in ((refs_main, True), (refs_plus, False)):
				for url in set(refs.split()):
					parsed_url = ParseURL(url)
					if not parsed_url:
						continue
					parsed_url[FROM_CPE_DICT] = from_cpe_dict
					# Deduplicate per product by the reconstructed URL string
					reconstr_url = JoinURL(
						parsed_url[URL_NETLOC],
						parsed_url[URL_PATH],
						parsed_url[URL_QUERY],
						)
					if not pre_res[(vendor, product)].get(reconstr_url):
						pre_res[(vendor, product)][reconstr_url] = parsed_url
						UpdateSupplementaryData(product, parsed_url)
		res = defaultdict(list)
		for vendor_and_product, urls in pre_res.items():
			for url in urls.values():
				url[UNIQUE_NETLOC], url[UNIQUE_PATH] = \
					UniqueNetlocAndPath(url)
				url[GOT_SUBDOMAIN] = url[URL_NETLOC].count('.') >= 2
				res[vendor_and_product].append(url)

	return res

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Check if a first given path is a subpath of a second one

def FirstInSecond(x_path, y_path):

	# y_path must continue past x_path with a '/' separator; testing the
	# extended prefix covers the length, prefix and separator checks at once
	return y_path.startswith(x_path + '/')

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Compare given URL-addresses

def CompareURLs(package_url, product_url):
	"""Decide whether two parsed URLs plausibly point at the same project.

	Both arguments are parsed-URL records indexed by the URL_* / UNIQUE_* /
	GOT_SUBDOMAIN constants.  Returns True on a match, False otherwise.
	"""

	if package_url[UNIQUE_NETLOC] and product_url[UNIQUE_NETLOC]:
		# Both netlocs uniquely identify their names, so compare them directly
		length_d = len(package_url[URL_NETLOC]) - len(product_url[URL_NETLOC])
		if length_d == 0:
			# Equal length: netlocs must be identical (hash comparison)
			if package_url[URL_NETLOC_HASH] != product_url[URL_NETLOC_HASH]:
				return False
		elif length_d < 0:
			# Product netloc is longer: accept only if it is a subdomain of
			# the package netloc AND the subdomain prefix reappears in the
			# package URL's path or query
			if not product_url[URL_NETLOC].endswith('.' + package_url[URL_NETLOC]):
				return False
			product_url_subdomain = product_url[URL_NETLOC][:-length_d - 1]
			return product_url_subdomain in package_url[URL_PATH] or \
				product_url_subdomain in package_url[URL_QUERY]
		else:
			# Symmetric case: the package netloc may be a subdomain of the
			# product netloc
			if not package_url[URL_NETLOC].endswith('.' + product_url[URL_NETLOC]):
				return False
			package_url_subdomain = package_url[URL_NETLOC][:length_d - 1]
			return package_url_subdomain in product_url[URL_PATH] or \
				package_url_subdomain in product_url[URL_QUERY]
		# Identical unique netlocs and at least one empty path: the netloc
		# alone is sufficient evidence
		if not package_url[URL_PATH] or not product_url[URL_PATH]:
			return True
	else:
		# At least one netloc is not unique: require exact netloc equality
		# and stronger path/query evidence
		if package_url[URL_NETLOC_HASH] != product_url[URL_NETLOC_HASH]:
			return False
		# Two path-less URLs still match when both netlocs carry a subdomain
		if not package_url[URL_PATH] and not product_url[URL_PATH] and \
				package_url[GOT_SUBDOMAIN] and product_url[GOT_SUBDOMAIN]:
			return True
		# Otherwise a bare non-unique netloc on either side is not enough
		if (not package_url[URL_PATH] and not package_url[URL_QUERY]) or \
				(not product_url[URL_PATH] and not product_url[URL_QUERY]):
			return False

	# Paths match when one is a unique prefix (subpath) of the other, or
	# when they are exactly equal (hash comparison)
	if not (package_url[UNIQUE_PATH] and
				FirstInSecond(package_url[URL_PATH], product_url[URL_PATH])) and \
			not (product_url[UNIQUE_PATH] and
				FirstInSecond(product_url[URL_PATH], package_url[URL_PATH])) and \
			package_url[URL_PATH_HASH] != product_url[URL_PATH_HASH]:
		return False

	# Query strings, when both present, must be identical
	return (not package_url[URL_QUERY] or not product_url[URL_QUERY]) or \
		package_url[URL_QUERY_HASH] == product_url[URL_QUERY_HASH]

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Find url matches, send results to a given pipe

def GetMatches(package_names, setup, send_pipe, err_event):
	"""Match package URLs against product URLs, report through a pipe.

	For every package in 'package_names' whose URL is known, collect
	"vendor:product" strings whose URLs compare equal to the package's URL,
	grouped into 'grade_a' (product URL came from the CPE dictionary) and
	'grade_b' (any other source).  The resulting mapping
	{package: {grade: '  '-joined products}} is sent through 'send_pipe';
	setting 'err_event' aborts the scan early.
	"""

	matches = defaultdict(dict)

	for package in package_names:
		if err_event.is_set():
			break
		per_package = defaultdict(list)
		package_url = setup.urls_of_packages.get(package)
		if not package_url:
			continue
		for vendor_and_product, product_urls in setup.urls_of_products.items():
			# At most one URL per product may contribute a match
			for product_url in product_urls:
				if product_url and CompareURLs(package_url, product_url):
					grade = 'grade_a' if product_url[FROM_CPE_DICT] \
						else 'grade_b'
					per_package[grade].append(':'.join(vendor_and_product))
					break
		for grade, products in per_package.items():
			matches[package][grade] = '  '.join(products)

	send_pipe.send(matches)
	send_pipe.close()

	return

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

if __name__ == '__main__':

	# Launch the multi-branch URL matcher and translate its boolean result
	# into a conventional process exit status
	exit(0 if Run(M_URL, GetMatches, [QueryURLs]) else 1)
