#!/usr/bin/python3

#	cve-manager : CVE management tool
#	Copyright (C) 2017-2022 Alexey Appolonov
#
#	This program is free software: you can redistribute it and/or modify
#	it under the terms of the GNU General Public License as published by
#	the Free Software Foundation, either version 3 of the License, or
#	(at your option) any later version.
#
#	This program is distributed in the hope that it will be useful,
#	but WITHOUT ANY WARRANTY; without even the implied warranty of
#	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#	GNU General Public License for more details.
#
#	You should have received a copy of the GNU General Public License
#	along with this program.  If not, see <http://www.gnu.org/licenses/>.

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

from collections          import defaultdict
from cve_manager.db       import DB, Placeholders
from cve_manager.defines  import NVD_DATA_SRC, FSTEC_DATA_SRC
from cve_manager.url      import ParseURL
from cpe_map.defines      import M_URL
from cpe_map.multi_branch import Run

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Get URL-addresses of packages of a given branch (optionally limited to
# the package names listed in "selection")
# {package: parsed_url}
# or get URL-addresses of all the products of a given data source
# {(vendor, product): [parsed_url1, parsed_url2, ..]}
#
# Exactly one of "branch" and "data_source" has to be given. Every parsed
# URL is the value returned by ParseURL (indexed here by 'netloc' and
# 'path') extended with the boolean keys 'unique_netloc' and 'unique_path';
# URLs of products also get 'from_cpe_dict', which is True for references
# read from the third selected column and False for those read from the
# fourth one.
#
# None is returned if both or none of "branch" and "data_source" are given,
# if the data source is unknown, if the DB connection can't be established
# or if the query yields no cursor.

def QueryURLs(mysql_config, branch=None, data_source=None, selection=None):

	# Either packages of a branch or products of a data source are queried,
	# never both and never none
	if bool(branch) == bool(data_source):
		return None

	query_urls_of_packages = bool(branch)

	# No mutable default argument; a missing selection means "no filter"
	selection = tuple(selection or ())

	db = DB()
	if not db.Connect(mysql_config):
		return None

	if query_urls_of_packages:
		table = f'{branch}_src'
		query = f"SELECT name, url, relatives FROM {table} WHERE url <> '' "
		if selection:
			query += f'AND name in ({Placeholders(len(selection))})'
		tables = (table,)
		data = [selection]
	elif data_source == NVD_DATA_SRC:
		table_a = 'cpe_import'
		table_b = f'{NVD_DATA_SRC}_products'
		# 1) refs of table_a for products that are also present in table_b,
		# 2) refs of table_b for vendor/product pairs that have no refs in
		#    table_a,
		# 3) refs of table_b for vendor/product pairs missing from table_a
		query = ("("
			f"SELECT vendor, product, refs, '' FROM {table_a} a "
			f"WHERE refs <> '' AND EXISTS (SELECT 1 FROM {table_b} b "
				"WHERE a.product = b.product)"
			") UNION ("
			f"SELECT DISTINCT b.vendor, b.product, '', b.refs FROM {table_b} b "
			f"INNER JOIN {table_a} a "
				"ON a.vendor = b.vendor AND a.product = b.product "
			"WHERE a.refs = '' AND b.refs <> ''"
			") UNION ("
			f"SELECT vendor, product, '', refs FROM {table_b} x "
			f"WHERE NOT EXISTS (SELECT 1 FROM {table_a} y "
				"WHERE x.vendor = y.vendor AND x.product = y.product) "
				"AND refs <> ''"
			")")
		tables = (table_a, table_b)
		data = []
	elif data_source == FSTEC_DATA_SRC:
		table = f'{FSTEC_DATA_SRC}_products'
		query = f"SELECT vendor, product, refs, '' FROM {table} " \
			f"WHERE refs <> ''"
		tables = (table,)
		data = []
	else:
		return None
	cursor = db.Cursor(query, tables=tables, data=data)
	if not cursor:
		return None

	# {netloc: {path: {names that have a URL with this netloc and path}}}
	names_of_paths_of_netlocs = defaultdict(lambda: defaultdict(set))
	# {netloc: {names that have a URL with this netloc}}
	names_of_netlocs = defaultdict(set)

	# Write given name and a netloc and a path of a given URL into helper
	# data structures
	def UpdateSupplementaryData(name, url):
		names_of_paths_of_netlocs[url['netloc']][url['path']].add(name)
		names_of_netlocs[url['netloc']].add(name)
	# - nested func

	# Check if netloc and path of a given URL are unique,
	# returns (unique_netloc, unique_path)
	def UniqueNetlocAndPath(name, url, relatives=None):
		unique_netloc = \
			BasicUniquenessCheck(name, names_of_netlocs[url['netloc']], relatives)
		# This netloc can have multiple paths for multiple names, and
		# definitely has path(s) for a name that is currently processed
		names_of_paths = names_of_paths_of_netlocs[url['netloc']]
		# Path is not unique if various names have it
		this_path = url['path']
		if not BasicUniquenessCheck(name, names_of_paths[this_path], relatives):
			return unique_netloc, False
		# Path is unique if no longer path of this netloc starts with it
		# (the path itself is never strictly longer, so it needs no special
		# treatment)
		unique_path = not any(
			len(other_path) > len(this_path) and
			other_path.startswith(this_path)
			for other_path in names_of_paths)
		return unique_netloc, unique_path
	# - nested func

	# Domain name or path can't be unique if it's shared among several names,
	# unless all of these names belong to a single set of relatives
	def BasicUniquenessCheck(name, names_of_this_netloc_or_path, relatives):
		if len(names_of_this_netloc_or_path) == 1:
			return True
		if relatives:
			return not (names_of_this_netloc_or_path - relatives[name] - {name})
		return False
	# - nested func

	if query_urls_of_packages:
		res = {}
		relatives = defaultdict(set)
		for name, url, rel in cursor:
			parsed_url = ParseURL(url)
			res[name]  = parsed_url
			# A NULL value of the "relatives" column is treated as "no
			# relatives"
			relatives[name] = set((rel or '').split())
			UpdateSupplementaryData(name, parsed_url)
		# Uniqueness can be judged only when all the URLs are collected
		for name, url in res.items():
			url['unique_netloc'], url['unique_path'] = \
				UniqueNetlocAndPath(name, url, relatives)
	else:
		res = defaultdict(list)
		for vendor, product, refs_main, refs_plus in cursor:
			for refs, from_cpe_dict in ((refs_main, True), (refs_plus, False)):
				for url in set(refs.split()):
					parsed_url = ParseURL(url)
					parsed_url['from_cpe_dict'] = from_cpe_dict
					res[(vendor, product)].append(parsed_url)
					# Uniqueness is tracked by the product name alone
					UpdateSupplementaryData(product, parsed_url)
		# Uniqueness can be judged only when all the URLs are collected
		for vendor_and_product, urls in res.items():
			_, product = vendor_and_product
			for url in urls:
				url['unique_netloc'], url['unique_path'] = \
					UniqueNetlocAndPath(product, url)

	return res

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Decide if a URL of a package and a URL of a product point to the same
# place. Both arguments are mappings with the keys 'netloc', 'path',
# 'query', 'unique_netloc' and 'unique_path'.

def CompareURLs(package_url, product_url):

	# Check if a given netloc has two dots or more
	def AtLeastTwoDots(netloc):
		return netloc.count('.') >= 2
	# - nested func

	# Check if a second given path continues a first one with a '/'
	def IsParentPath(parent, child):
		return child.startswith(parent + '/')
	# - nested func

	# Netloc is mandatory, netlocs and queries must be the same
	pkg_netloc = package_url['netloc']
	if not pkg_netloc:
		return False
	if pkg_netloc != product_url['netloc']:
		return False
	if package_url['query'] != product_url['query']:
		return False

	pkg_path = package_url['path']
	if not (pkg_path and product_url['path']):
		# At least one of the paths is missing: netlocs alone make a match
		# if both of them are unique, or if there are no paths at all and
		# netlocs have two dots or more
		both_unique = \
			package_url['unique_netloc'] and product_url['unique_netloc']
		if both_unique:
			return both_unique
		return (
			not package_url['path'] and
			not product_url['path'] and
			AtLeastTwoDots(package_url['netloc']) and
			AtLeastTwoDots(product_url['netloc'])
			)

	prod_path = product_url['path']

	# Unique path matches any path that lies under it
	if package_url['unique_path'] and IsParentPath(pkg_path, prod_path):
		return True
	if product_url['unique_path'] and IsParentPath(prod_path, pkg_path):
		return True

	return pkg_path == prod_path

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Find products whose URLs match URLs of given packages and send the result
# {package: {'grade_a' or 'grade_b': 'vendor:product  vendor:product  ..'}}
# to a given pipe; processing of packages stops as soon as a given error
# event is set (what is collected by that moment is still sent)

def GetMatches(package_names, setup, send_pipe, err_event):

	found = defaultdict(dict)

	for pkg_name in package_names:
		if err_event.is_set():
			break
		pkg_url = setup.urls_of_packages.get(pkg_name)
		if not pkg_url:
			continue
		products_by_grade = defaultdict(list)
		for product_key, candidate_urls in setup.urls_of_products.items():
			# Only the first matching URL of a product is taken into account
			hit = next(
				(candidate for candidate in candidate_urls
					if candidate and CompareURLs(pkg_url, candidate)),
				None)
			if hit is None:
				continue
			grade = 'grade_a' if hit['from_cpe_dict'] else 'grade_b'
			products_by_grade[grade].append(":".join(product_key))
		for grade, products in products_by_grade.items():
			found[pkg_name][grade] = '  '.join(products)

	send_pipe.send(found)

	return

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

if __name__ == '__main__':

	# The matcher and the function that queries URLs are handed over to Run
	# (imported from cpe_map.multi_branch); its result is used as a success
	# flag
	success = Run(M_URL, GetMatches, [QueryURLs])

	# SystemExit is raised directly, as the exit() builtin is provided by
	# the site module for interactive use and is absent without it
	# (e.g. "python3 -S")
	raise SystemExit(0 if success else 1)
