#!/usr/bin/python3

#	cve-manager : CVE management tool
#	Copyright (C) 2017-2022 Alexey Appolonov
#
#	This program is free software: you can redistribute it and/or modify
#	it under the terms of the GNU General Public License as published by
#	the Free Software Foundation, either version 3 of the License, or
#	(at your option) any later version.
#
#	This program is distributed in the hope that it will be useful,
#	but WITHOUT ANY WARRANTY; without even the implied warranty of
#	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#	GNU General Public License for more details.
#
#	You should have received a copy of the GNU General Public License
#	along with this program.  If not, see <http://www.gnu.org/licenses/>.

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

from sys                       import argv
from cve_manager.common        import GetDef, Init
from cve_manager.conf          import DB_CON_SEC, COMMON_SEC, LOCAL_SYS
from cve_manager.defines       import FSTEC_DATA_SRC
from cve_manager.ignored_pairs import IgnoredPairs
from cpe_map.common            import NewArgParser
from cpe_map_choice.mediator   import Mediator
from cpe_map_choice.solver     import Solver

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Forming a set of arguments

argparser = NewArgParser(ptype='bp')
argparser.add_argument(
	'-v', '--verbose',
	action='store_true',
	help='Enable verbose output'
	)

# When the script is started without any arguments, show the usage and stop.
# The check is guarded so that merely importing this module does not print
# the help text and terminate the importing process.
if __name__ == '__main__' and len(argv) < 2:
	argparser.print_help()
	exit(1)

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

if __name__ == '__main__':

	# Read the command line
	args = argparser.parse_args()

	# Basic helpers: an output printer and a configuration reader
	printer, conf = Init(args)

	# The DB connection section and the common section of the configuration
	mysql_config, common_params = conf.Get([DB_CON_SEC, COMMON_SEC])

	# Higher level helpers built on top of the configuration
	mediator = Mediator(mysql_config, printer)
	solver = Solver(printer, args.verbose)

	# Select the branches to process; when 'all' is requested an empty
	# selection is passed to the configuration (presumably meaning "every
	# known branch" — see conf.SelectBranches)
	requested_branches = [] if 'all' in args.branches else args.branches
	branches, err = conf.SelectBranches(requested_branches)
	if err or not branches:
		printer.Err(err or 'Can\'t form a list of branch names')
		exit(1)

	# Get the src & bin branch map of a branch for a given data source. An
	# empty (or missing) map terminates the script, except for the FSTEC data
	# source, for which it is returned as is
	def BranchMap(branch, data_source):
		result = mediator.QueryBranchMap(branch, data_source, args.packages)
		if result or data_source == FSTEC_DATA_SRC:
			return result
		exit(1)

	# Cache of ignored mapping pairs: {<data_source>: <ignored_pairs>}
	ignored_pairs = {}

	# Get ignored mapping pairs of a data source by using a dedicated object.
	# The list is read once per data source and then served from the cache;
	# a failed read (None) is reported and terminates the script
	def GetIgnoredPairs(data_source):
		global ignored_pairs
		# Membership test rather than truthiness, so that an empty (but
		# successfully read) ignore list is cached too and is not re-read,
		# with a repeated "Loading ..." message, for every branch
		if data_source in ignored_pairs:
			return ignored_pairs[data_source]
		printer.LineBegin(f'Loading {data_source.upper()} ignore list')
		ignored, msg = \
			IgnoredPairs(GetDef('CONF_DIR', args.debug), data_source).Read()
		if ignored is None:
			# Read() reports a failure with None plus an error message
			printer.Err(msg)
			exit(1)
		printer.Success(msg)
		ignored_pairs[data_source] = ignored
		return ignored

	# Cache of query results:
	# {<data_source>: (<non_supplementary_products>, <supplementary_products>)}
	products_of_all_data_sources = {}

	# Get a dict of supplementary products of a given data source and a set of
	# names that are available for supplementary products and that are not
	# available for non-supplementary products.
	# Both product dicts map keys to iterables of product names. If nothing is
	# found, the script terminates, except for the FSTEC data source, for which
	# an empty dict and an empty set are returned (not cached)
	def SupplementaryProducts(data_source):
		global products_of_all_data_sources
		products = products_of_all_data_sources.get(data_source)
		if products is None:
			products = mediator.QueryNotCorruptedAndCorruptedProducts(data_source)
			# Also guard against the query returning None instead of a pair
			if not products or not any(products):
				if data_source == FSTEC_DATA_SRC:
					# A set, like the names value of the regular path
					return {}, set()
				exit(1)
			products_of_all_data_sources[data_source] = products
		non_supplementary_products, supplementary_products = products
		# Forming a set of product names of non-supplementary products
		non_supplementary_product_names = {
			name
			for names in non_supplementary_products.values()
			for name in names}
		# Forming a set of product names of supplementary products that are not
		# present in a set of product names of non-supplementary products
		supplementary_product_names = {
			name
			for names in supplementary_products.values()
			for name in names
			if name not in non_supplementary_product_names}
		return supplementary_products, supplementary_product_names

	# Get a dict of related packages {<package_name>: [<related_package_names>]}
	# of a specified branch; a failed query (None) terminates the script
	def RelatedPackages(branch):
		packages = mediator.QueryRelatedPackages(branch, args.packages)
		if packages is not None:
			return packages
		exit(1)

	# Cache of related products: {<data_source>: <related_products>}
	related_products_of_all_data_sources = {}

	# Get a dict of related products {<product_name>: [<related_product_names>]}
	# of a specified data source. Queried once per data source and cached; a
	# failed query (None) terminates the script
	def RelatedProducts(data_source):
		global related_products_of_all_data_sources
		# Membership test: an empty result is accepted (only None is fatal), so
		# it is cached as well instead of being re-queried for every branch
		if data_source in related_products_of_all_data_sources:
			return related_products_of_all_data_sources[data_source]
		related_products = mediator.QueryRelatedProducts(data_source)
		if related_products is None:
			exit(1)
		related_products_of_all_data_sources[data_source] = related_products
		return related_products

	# Cache of current package versions: {<branch>: <current_versions>}
	current_versions_of_packages_of_all_branches = {}

	# Get current versions of packages of a branch. The result is cached per
	# branch; an empty or failed query terminates the script
	def CurrentVersionsOfPackages(branch):
		global current_versions_of_packages_of_all_branches
		versions = current_versions_of_packages_of_all_branches.get(branch)
		if not versions:
			versions = mediator.QueryCurrentPackageVersions(branch)
			if not versions:
				exit(1)
			current_versions_of_packages_of_all_branches[branch] = versions
		return versions

	# Cache of product time lines: {<data_source>: <time_lines>}
	time_lines_of_products_of_all_data_sources = {}

	# Get time lines of products of a data source. Queried once per data source
	# and cached; a failed query (None) terminates the script
	def TimeLinesOfProducts(data_source):
		global time_lines_of_products_of_all_data_sources
		# Membership test: an empty result is accepted (only None is fatal), so
		# it is cached as well instead of being re-queried for every branch
		if data_source in time_lines_of_products_of_all_data_sources:
			return time_lines_of_products_of_all_data_sources[data_source]
		time_lines = mediator.QueryTimeLines(data_source=data_source)
		if time_lines is None:
			exit(1)
		time_lines_of_products_of_all_data_sources[data_source] = time_lines
		return time_lines

	# Cache of package time lines: {<branch>: <time_lines>}
	time_lines_of_packages_of_all_branches = {}
	# Placeholder until the master branch time lines are loaded below
	time_lines_of_master_branch = {}

	# Get time lines of packages of a branch. A failed query (None) terminates
	# the script. An empty result also terminates it for the master branch,
	# while any other branch falls back to the master branch time lines. Only
	# the master branch result is stored in the cache
	def TimeLinesOfPackages(branch, is_master_branch=False):
		global time_lines_of_packages_of_all_branches
		cached = time_lines_of_packages_of_all_branches.get(branch)
		if cached:
			return cached
		fetched = mediator.QueryTimeLines(branch=branch)
		if fetched is None:
			exit(1)
		if not fetched:
			if is_master_branch:
				exit(1)
			return time_lines_of_master_branch
		if is_master_branch:
			time_lines_of_packages_of_all_branches[branch] = fetched
		return fetched

	# Load (and cache) the master branch time lines first: they are the
	# fallback for branches that have no time lines of their own
	master_branch = common_params['master_branch']
	time_lines_of_master_branch = TimeLinesOfPackages(master_branch, True)

	# Process every selected branch against every requested data source
	last_branch_index = len(branches) - 1
	for i, branch in enumerate(branches):
		# Per-branch inputs: related packages, current versions of packages and
		# time lines of packages (for the local system the time lines are taken
		# from the branch returned by conf.GetLocalSysBranch())
		related_packages = RelatedPackages(branch)
		current_versions_of_packages = CurrentVersionsOfPackages(branch)
		if branch == LOCAL_SYS:
			time_lines_branch = conf.GetLocalSysBranch()
		else:
			time_lines_branch = branch
		time_lines_of_packages = TimeLinesOfPackages(time_lines_branch)
		# For all mapping targets
		for j, data_source in enumerate(args.data_sources):
			# Per-data-source inputs: branch map, ignored mapping pairs,
			# supplementary products, product time lines and related products
			branch_map = BranchMap(branch, data_source)
			ignore = GetIgnoredPairs(data_source)
			supplementary_products, supplementary_products_flat = \
				SupplementaryProducts(data_source)
			time_lines_of_products = TimeLinesOfProducts(data_source)
			related_products = RelatedProducts(data_source)
			is_fstec = data_source == FSTEC_DATA_SRC
			# Making the choice; None signals a failure
			choices = solver.MakeChoice(
				branch,
				branch_map,
				related_packages,
				ignore,
				supplementary_products,
				supplementary_products_flat,
				current_versions_of_packages,
				time_lines_of_packages,
				time_lines_of_products,
				related_products,
				is_fstec,
				)
			if choices is None:
				exit(1)
			if args.noupdate:
				# Dry run: display the choices instead of updating the DB
				for name, match in choices.items():
					printer.LineEnd(f'\t- {name} >> {match}')
			elif choices:
				# Send non-empty results to a src table
				if not mediator.SendMap(choices, branch, data_source):
					exit(1)
			if j != len(args.data_sources) - 1:
				printer.LineEnd('~~~ Going for a next data source ~~~')
		if i != last_branch_index:
			printer.LineEnd('~~~ Going for a next branch ~~~')

	exit(0)
