#!/usr/bin/python3

#	cve-manager : CVE management tool
#	Copyright (C) 2017-2024 Alexey Appolonov
#
#	This program is free software: you can redistribute it and/or modify
#	it under the terms of the GNU General Public License as published by
#	the Free Software Foundation, either version 3 of the License, or
#	(at your option) any later version.
#
#	This program is distributed in the hope that it will be useful,
#	but WITHOUT ANY WARRANTY; without even the implied warranty of
#	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#	GNU General Public License for more details.
#
#	You should have received a copy of the GNU General Public License
#	along with this program.  If not, see <http://www.gnu.org/licenses/>.

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

from collections          import defaultdict
from cve_manager.common   import Init as InitCommon, GetDef
from cve_manager.conf     import DB_CON_SEC
from cve_manager.db       import DB
from cve_manager.defines  import FSTEC_DATA_SRC, M_SPECIAL
from cve_manager.desc     import FSTEC_MAP_SPECIAL
from cve_manager.software import Software
from cpe_map.common       import MatchesStr, NewArgParser
from cpe_map.init         import Init as InitMappingHelpers

# Lowest acceptable match score: packages whose score (computed by
# LinearFuncWithConstraint in the main block) falls below this threshold are
# left out of the matches sent to the database
MIN_SCORE = 0.1

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Parsing the arguments, initialising helper objects and reading the conf file

# Parse the command line, build the printer and configuration helpers from
# the parsed arguments, and extract the database connection settings (the
# configuration getter is expected to yield exactly one section here)
args = NewArgParser(
	ptype='b',
	description=FSTEC_MAP_SPECIAL,
).parse_args()
printer, conf = InitCommon(args)
[mysql_config] = conf.Get([DB_CON_SEC])

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Get a dict of sets of FSTEC products of every FSTEC entry {bdu_id: {products}}

def GetFSTECProducts():
	"""Collect the product names mentioned by every FSTEC entry.

	Reads vul_year, vul_num and software_list from the fstec_vul_import
	table. Every whitespace-separated item of software_list is wrapped in a
	Software object and its Product() is recorded under the entry's ID.

	Returns:
		defaultdict(set) {bdu_id: {products}}, where bdu_id is
		'<vul_year>-<vul_num>'; entries with an empty software_list are
		absent. None if connecting to the database or selecting fails.
	"""

	printer.LineBegin('Getting the vulnerability IDs and product names '
		'of every FSTEC entry')

	db = DB(printer)
	if not db.Connect(mysql_config):
		return None

	table  = 'fstec_vul_import'
	fields = ['vul_year', 'vul_num', 'software_list']
	cursor = db.Select(fields, table)
	if cursor is None:
		return None

	ret = defaultdict(set)
	for bdu_year, bdu_num, software_list in cursor:
		bdu_id = f'{bdu_year}-{bdu_num}'
		# Guard against a NULL column (None): treat it as "no software"
		# instead of crashing on .split()
		for item in (software_list or '').split():
			ret[bdu_id].add(Software(item).Product())

	printer.Success()

	return ret

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Get a dict of sets of CVE IDs of every FSTEC entry {bdu_id: {cve_ids}}

def GetCVEIDsOfFSTECEntries():
	"""Collect the CVE IDs referenced by every FSTEC entry.

	Selects vul_year, vul_num and cve_ids from the fstec_vul_import table,
	skipping rows whose cve_ids column is an empty string, and splits
	cve_ids on whitespace.

	Returns:
		dict {bdu_id: {cve_ids}}, where bdu_id is '<vul_year>-<vul_num>',
		or None if connecting to the database or selecting fails.
	"""

	printer.LineBegin('Getting the CVE IDs of the FSTEC entries')

	db = DB(printer)
	if not db.Connect(mysql_config):
		return None

	table  = 'fstec_vul_import'
	fields = ['vul_year', 'vul_num', 'cve_ids']
	# Only entries that actually reference some CVE are of interest
	cond   = "cve_ids <> ''"
	cursor = db.Select(fields, table, cond=cond)
	if cursor is None:
		return None

	ret = {}
	for bdu_year, bdu_num, cve_ids in cursor:
		# Merge instead of overwriting, so that several rows sharing the
		# same BDU ID do not silently drop each other's CVE IDs
		ret.setdefault(f'{bdu_year}-{bdu_num}', set()).update(cve_ids.split())

	printer.Success()

	return ret

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Get a dict of sets of package names of every NVD issue {cve_id: {packages}}

def GetCVEIDsOfPackages(branch):
	"""Collect the names of the `branch` packages related to every NVD issue.

	Joins <branch>_nvd_issues with <branch>_src (package_id = id) and groups
	the package names by the issue ID '<vul_year>-<vul_num>'.

	NOTE(review): these keys are looked up with the CVE IDs read from
	fstec_vul_import.cve_ids, so both must use the same textual format —
	confirm.

	Args:
		branch: branch name used as the prefix of the table names.

	Returns:
		defaultdict(set) {cve_id: {package_names}}, or None if connecting
		to the database or running the query fails.
	"""

	printer.LineBegin(f'Getting the names of {branch} packages of every NVD '
		'issue')

	db = DB(printer)
	if not db.Connect(mysql_config):
		return None

	# Table names cannot be bound as query parameters, hence the f-string;
	# `branch` is assumed to come from the configured branch list — confirm
	# it is never taken verbatim from untrusted input
	tables = [f'{branch}_nvd_issues', f'{branch}_src']
	query  = f'SELECT a.vul_year, a.vul_num, b.name FROM {tables[0]} a ' \
		f'INNER JOIN {tables[1]} b ON a.package_id = b.id'
	cursor = db.Cursor(query, tables=tables)
	if cursor is None:
		return None

	ret = defaultdict(set)
	for cve_year, cve_num, package_name in cursor:
		ret[f'{cve_year}-{cve_num}'].add(package_name)

	printer.Success()

	return ret

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Use a dict of packages related to BDU IDs and a dict of FSTEC products
# related to BDU IDs to form a dict of FSTEC products related to package names
# {package_name: {fstec_products}}

def MapFSTECProductsToPackages(bdu_to_packages, bdu_to_fstec_products):
	"""Relate package names to FSTEC products via shared BDU entries.

	For every package, the FSTEC product sets of the BDU entries related to
	it are intersected in the iteration order of `bdu_to_packages`. When the
	running intersection becomes empty, the next BDU entry starts it over.
	NOTE(review): this reset makes the result order-dependent — confirm it
	is intended.

	Neither input dict is modified. Each package gets its own copy of the
	first product set, so the later in-place intersection cannot alter
	`bdu_to_fstec_products` (which may be reused across calls) or the sets
	of other packages.

	Args:
		bdu_to_packages: {bdu_id: {package_names}}
		bdu_to_fstec_products: {bdu_id: {fstec_products}}

	Returns:
		dict {package_name: {fstec_products}} with empty sets dropped.
	"""

	printer.LineBegin('Searching for the FSTEC products related to '
		'the packages via CVE IDs')

	packages_to_fstec_products = {}
	for bdu_id, packages in bdu_to_packages.items():
		fstec_products = bdu_to_fstec_products.get(bdu_id)
		if not fstec_products:
			continue
		for package in packages:
			current = packages_to_fstec_products.get(package)
			if current:
				# In-place intersection of this package's own private set
				current &= fstec_products
			else:
				# Copy, never alias: a later '&=' on a shared set would
				# corrupt the caller's data and other packages' results
				packages_to_fstec_products[package] = set(fstec_products)

	# Filter out empty sets
	packages_to_fstec_products = \
		{k: v for k, v in packages_to_fstec_products.items() if v}

	msg = f'{len(packages_to_fstec_products)} package(s) related to some ' \
		'of the FSTEC products'
	if packages_to_fstec_products:
		printer.Success(msg)
	else:
		printer.Warn(msg)

	return packages_to_fstec_products

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

if __name__ == '__main__':

	# sys.exit instead of the builtin exit(): the latter is injected by the
	# site module for interactive use and is missing under 'python -S'
	import sys

	# Forming a list of branches that will be processed (when 'all' is
	# requested an empty selection is passed — presumably SelectBranches
	# then returns every configured branch; confirm)
	branches, err = \
		conf.SelectBranches(args.branches if 'all' not in args.branches else [])
	if err or not branches:
		printer.Err(err if err else 'Can\'t form a list of branch names')
		sys.exit(1)

	# Creating an object that simplifies sending data to the database
	_, mediator = InitMappingHelpers(M_SPECIAL, args, no_speaker=True)

	# BDU IDs --> FSTEC products
	bdu_to_fstec_products = GetFSTECProducts()
	if bdu_to_fstec_products is None:
		sys.exit(1)

	# BDU IDs --> CVE IDs
	bdu_to_cve = GetCVEIDsOfFSTECEntries()
	if bdu_to_cve is None:
		sys.exit(1)

	# Score for this type of matching: the fewer FSTEC products a package
	# is related to, the higher the score (1 product -> 1.0, 10 products ->
	# 0.1, any count outside 1..10 -> 0.0)
	def LinearFuncWithConstraint(x):
		y = 1.1 - 0.1 * x if 1 <= x <= 10 else 0.0
		return round(y, 2)

	# This is a multibranch type of matching
	for branch in branches:

		# CVE IDs --> packages
		cve_to_packages = GetCVEIDsOfPackages(branch)
		if cve_to_packages is None:
			sys.exit(1)

		# BDU IDs --> (CVE IDs) --> packages
		bdu_to_packages = defaultdict(set)
		for bdu_id, cve_ids in bdu_to_cve.items():
			for cve_id in cve_ids:
				packages = cve_to_packages.get(cve_id)
				if packages:
					bdu_to_packages[bdu_id] |= packages

		# Packages --> (BDU IDs) --> FSTEC products
		packages_to_fstec_products = MapFSTECProductsToPackages(
			bdu_to_packages, bdu_to_fstec_products)
		if not packages_to_fstec_products:
			continue

		# Convert to a format used for matches, skipping packages whose
		# score is below MIN_SCORE (i.e. related to too many products)
		matches = {}
		for package, fstec_products in packages_to_fstec_products.items():
			score = LinearFuncWithConstraint(len(fstec_products))
			if score < MIN_SCORE:
				continue
			matches[package] = MatchesStr(
				{product: score for product in fstec_products})

		# Updating the src_to_fstec table with the results
		if not mediator.SendMatches(
				FSTEC_DATA_SRC, matches, f'{branch}_{M_SPECIAL}'):
			sys.exit(1)

	sys.exit(0)
