#!/usr/bin/python3

#	cve-manager : CVE management tool
#	Copyright (C) 2017-2022 Alexey Appolonov
#
#	This program is free software: you can redistribute it and/or modify
#	it under the terms of the GNU General Public License as published by
#	the Free Software Foundation, either version 3 of the License, or
#	(at your option) any later version.
#
#	This program is distributed in the hope that it will be useful,
#	but WITHOUT ANY WARRANTY; without even the implied warranty of
#	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#	GNU General Public License for more details.
#
#	You should have received a copy of the GNU General Public License
#	along with this program.  If not, see <http://www.gnu.org/licenses/>.

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

from collections          import defaultdict
from cve_manager.defines  import VUL_MARKERS
from cve_manager.db       import DB, Placeholders
from cve_manager.software import DistinctSoftware
from cpe_map.defines      import M_FIXES
from cpe_map.multi_branch import Run
from cpe_map.common       import MatchesStr, CompatibleGroups
from cpe_map.name_conv    import NAME_GROUP

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Get 'Fixes' records of packages with given names

def QueryFixes(db, branch, data_source, package_names):
	"""Fetch the non-empty 'Fixes' records of the given packages.

	Args:
		db: Database object providing `Select`.
		branch: Branch name; records are read from the `<branch>_src` table.
		data_source: Key into VUL_MARKERS that determines which `*_fixes`
			column is read.
		package_names: Mapping of package name -> per-package data that is
			indexable by NAME_GROUP.

	Returns:
		Dict mapping package name -> (name group, set of the
		whitespace-separated ids found in the fixes column).
		Empty dict if there is nothing to look up or `Select` returns
		a falsy cursor.
	"""

	# Nothing to look up; this also avoids sending a "name in ()" clause,
	# which is not valid SQL
	if not package_names:
		return {}

	table = f'{branch}_src'
	# Column name is the marker without its last character, lower-cased,
	# followed by '_fixes' (presumably e.g. 'CVE-' -> 'cve_fixes' - verify
	# against VUL_MARKERS)
	fixes_col = f'{VUL_MARKERS.get(data_source, "")[:-1].lower()}_fixes'
	cols = ['name', fixes_col]
	cond = f"{fixes_col} <> '' AND {fixes_col} is not NULL AND " \
		f"name in ({Placeholders(len(package_names))})"
	# One placeholder per package name, values passed separately from the SQL
	data = [tuple(package_names)]

	cursor = db.Select(cols, table, cond, data=data)
	if not cursor:
		return {}

	return {
		package_name: (package_names[package_name][NAME_GROUP], set(fixes.split()))
		for package_name, fixes in cursor
	}

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Get product names of the corresponding vulnerabilities that are stated as fixed

def QueryProductNamesOfFixedVul(db, target, fixes, selection):
	"""Collect the selected products listed for the vulnerabilities in `fixes`.

	Args:
		db: Database object providing `Select`.
		target: Prefix of the `<target>_vul_import` table that is queried.
		fixes: Mapping of package name -> (group, set of vulnerability ids),
			where every id is expected to look like '<year>-<number>'.
		selection: Nested mapping vendor -> product -> data; only products
			whose data holds a non-None NAME_GROUP entry are reported.

	Returns:
		Mapping of '<year>-<number>' -> set of (vendor, product, group)
		tuples. Empty dict if there are no usable ids or `Select` returns
		a falsy cursor.
	"""

	import re

	if not fixes:
		return {}

	# The ids are interpolated into the SQL text below and come from
	# 'Fixes' records, so only well-formed '<digits>-<alphanumerics>' ids are
	# accepted. Anything else is skipped rather than crashing the unpacking
	# or leaking quotes into the query.
	# NOTE(review): assumes the number part is ASCII alphanumeric - confirm
	well_formed_id = re.compile(r'([0-9]+)-([0-9A-Za-z]+)')

	ids_of_fixed_vul = defaultdict(set)
	for _, vul_ids in fixes.values():
		for vul_id in vul_ids:
			parsed = well_formed_id.fullmatch(vul_id)
			if parsed is None:
				continue
			year, num = parsed.groups()
			ids_of_fixed_vul[year].add(num)

	# Without ids the condition would be empty, so there is nothing to query
	if not ids_of_fixed_vul:
		return {}

	table = f'{target}_vul_import'
	cols  = ['vul_year', 'vul_num', 'software_list']
	cond  = 'OR '.join(f"(vul_year = {year} AND vul_num = '{num}') "
		for year, nums in ids_of_fixed_vul.items() for num in nums)

	cursor = db.Select(cols, table, cond)
	if not cursor:
		return {}

	software_of_vul_ids = defaultdict(set)
	for vul_year, vul_num, software_listing in cursor:
		vul_id = f'{vul_year}-{vul_num}'
		for software in DistinctSoftware(software_listing):
			product = software.Product()
			# '-' product entries are skipped
			if product == '-':
				continue
			vendor = software.Vendor()
			group = selection.get(vendor, {}).get(product, {}).get(NAME_GROUP)
			# Keep only products that are present in the selection
			if group is not None:
				software_of_vul_ids[vul_id].add((vendor, product, group))

	return software_of_vul_ids

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Get names of products that correspond to 'Fixes' records of package changelogs

def GetMatches(package_names, setup, send_pipe, err_event):
	"""Match packages to products via their 'Fixes' records and send the result.

	Args:
		package_names: Mapping of package names handed over to QueryFixes.
		setup: Object providing `mysql_config`, `branch`, `data_source`
			and `products`.
		send_pipe: Object whose `send` receives the resulting dict of
			package name -> match string.
		err_event: Event that is set when the database connection fails;
			if it is found set while processing, the loop stops early and
			the matches collected so far are still sent.

	Returns:
		None. Nothing is sent if the database connection fails.
	"""

	# A product is accepted for a package only if it is mentioned in the
	# software listings of strictly more than this share of the package's
	# fixes-records
	ACCEPTABLE_PERCENTAGE = 0.75

	# Connecting to the CVE database
	db = DB(silent=True)
	if not db.Connect(setup.mysql_config):
		err_event.set()
		return

	# 'Fixes' records, then the products of the vulnerabilities they mention
	fixes = QueryFixes(db, setup.branch, setup.data_source, package_names)
	products_of_vul_ids = QueryProductNamesOfFixedVul(db, setup.data_source,
		fixes, setup.products)

	matches = {}

	for package_name, (package_group, fixed_ids) in fixes.items():
		if err_event.is_set():
			break

		# Mention counters: per vendor and product, and per product alone
		per_vendor = {}
		per_product = {}
		for fixed_id in fixed_ids:
			for vendor, product, product_group in \
					products_of_vul_ids.get(fixed_id, ()):
				if not CompatibleGroups(package_group, product_group):
					continue
				vendor_counts = per_vendor.setdefault(vendor, {})
				vendor_counts[product] = vendor_counts.get(product, 0) + 1
				per_product[product] = per_product.get(product, 0) + 1

		# Dropping the products that are not mentioned often enough and
		# composing the match string vendor by vendor
		threshold = len(fixed_ids) * ACCEPTABLE_PERCENTAGE
		match = ''
		for vendor, vendor_counts in per_vendor.items():
			accepted = {
				product: n
				for product, n in vendor_counts.items()
				if per_product[product] > threshold
			}
			if not accepted:
				continue
			separator = '  ' if match else ''
			match = match + separator + MatchesStr(accepted, vendor)

		if match:
			matches[package_name] = match

	send_pipe.send(matches)

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

if __name__ == '__main__':

	# Run GetMatches through the multi-branch runner in M_FIXES mode
	success = Run(M_FIXES, GetMatches)

	# Exit status: 0 on success, 1 otherwise. SystemExit is raised directly
	# because the `exit` helper is injected by the `site` module and is not
	# guaranteed to exist (e.g. under `python -S`)
	raise SystemExit(0 if success else 1)
