#!/usr/bin/python3

#	cve-manager : CVE management tool
#	Copyright (C) 2017-2024 Alexey Appolonov
#
#	This program is free software: you can redistribute it and/or modify
#	it under the terms of the GNU General Public License as published by
#	the Free Software Foundation, either version 3 of the License, or
#	(at your option) any later version.
#
#	This program is distributed in the hope that it will be useful,
#	but WITHOUT ANY WARRANTY; without even the implied warranty of
#	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#	GNU General Public License for more details.
#
#	You should have received a copy of the GNU General Public License
#	along with this program.  If not, see <http://www.gnu.org/licenses/>.

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

import argparse
from sys                        import argv
from collections                import defaultdict
from ax.ver                     import MergeVersions
from cve_manager.defines        import DATA_SOURCES, NVD_DATA_SRC, \
	FSTEC_DATA_SRC
from cve_manager.desc           import ISSUES
from cve_manager.common         import GetDef, NewArgParser, Init
from cve_manager.conf           import DB_CON_SEC, LOCAL_SYS
from cve_manager.parallel       import Parallel
from cve_manager.ignored_pairs  import IgnoredPairs
from cve_manager.package        import Package
from cve_manager.software       import SplitSoftware
from cve_issues.mediator        import Mediator
from cve_issues.excluded_issues import ExcludedIssues
from cve_issues.issues_checker  import IssuesChecker, VULNERABLE_VER
from cve_issues.vul_id          import VulID
from cve_issues.helpers         import GetLocalSysFullPackageNames

# Appended to the help of the options that force the screen-only mode (see the
# end of this section: they switch args.noupdate on)
NO_UPDATE = 'the results will be printed on screen and the DB won\'t be updated'

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Parsing the arguments

argparser = argparse.ArgumentParser(description=ISSUES)
# Used by the main block: the issues table of every processed branch is
# recreated, unless the DB is not going to be updated (--noupdate)
argparser.add_argument(
	'--prepare',
	action='store_true',
	help='Recreate issues tables'
	)
# The default empty list is passed as is to conf.SelectDataSources in the main
# block ("all by default" according to the help text)
argparser.add_argument(
	'-t', '--types',
	metavar='ISSUES_TYPE', type=str, nargs='+', default=[],
	choices=DATA_SOURCES,
	help=f'Issues types ({" or ".join(DATA_SOURCES)}, all by default)'
	)
# NewArgParser (cve_manager.common) returns an extended parser; ptype='bp'
# presumably adds the branch and package options that are read later as
# args.branches and args.packages -- TODO confirm
argparser = NewArgParser(ptype='bp', base=argparser)
argparser.add_argument(
	'--full_package_names',
	metavar='PACKAGE_NAME_VER_REL', type=str, nargs='+', default=[],
	help=f'Full package names (<name>-<version>-<release>), {NO_UPDATE}'
	)
argparser.add_argument(
	'--local_sys',
	action='store_true',
	help=f'Detect issues for all packages of the local system, {NO_UPDATE}'
	)
argparser.add_argument(
	'--noupdate',
	action='store_true',
	help='Do not update the issues tables, display the results on screen'
	)
# ptype='m' presumably adds the remaining common options (e.g. args.debug,
# which is read later) -- TODO confirm
argparser = NewArgParser(ptype='m', base=argparser)
args = argparser.parse_args()

# Without any arguments show the usage instead of running with the defaults
if len(argv) < 2:
	argparser.print_help()
	exit(1)

# Both package-name modes only print the results and never touch the DB
if args.full_package_names or args.local_sys:
	args.noupdate = True

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Initialising helper objects and reading the configuration file

# Reading the configuration and creating the objects used by the whole run
printer, conf = Init(args)
if args.noupdate:
	# In the screen-only mode the printer is presumably switched to showing
	# alerts only, in a plain form, so the printed issues are not mixed with
	# progress messages -- TODO confirm against the printer implementation
	for switch_on in (printer.AlertsOnly, printer.Plain):
		switch_on(True)
(mysql_config,) = conf.Get([DB_CON_SEC])
# The single gateway to the DB for the rest of the script
mediator = Mediator(mysql_config, printer)

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Handle a given portion of vulnerabilities (as a separate process)

# Bit flags accumulated over the products of one combination of vulnerable
# products: some mapped packages are vulnerable, some are not, or both
VULNERABLE = 1 << 0
NOT_VULNERABLE = 1 << 1
VULNERABLE_AND_NOT_VULNERABLE = VULNERABLE | NOT_VULNERABLE

def ProcessVulnerabilities(vul, issues_checker, send_pipe, err_event):
	'''Detect issues for one portion of vulnerabilities (run as a separate
	process by Parallel).

	vul            -- iterable of (vul_year, vul_num, software_list,
	                  cve_ids_of_fstec_entry) entries
	issues_checker -- IssuesChecker prepared for the target branch
	send_pipe      -- pipe end to which exactly one message, [err], is sent;
	                  err is an error description ('' if there is none)
	err_event      -- event that is set on failure; once it is set, no more
	                  vulnerabilities are processed and no issues are printed
	                  or sent to the DB

	With --noupdate the detected issues are printed on screen, otherwise they
	are sent to the DB through the mediator.
	'''

	err = ''
	failed = False
	# (vul_year, vul_num, package_id[, package_variant], fix) -> set of
	# vulnerable versions (package_variant is a part of the key only with
	# --noupdate)
	data = defaultdict(set)
	# package_id -> package_variant -> full package names; it is passed to
	# issues_checker.SetPackage and read back in the --noupdate mode
	issued_packages = defaultdict(lambda: defaultdict(set))
	# Issues of the combination being processed:
	# issue key -> (fix, vulnerable_by_itself)
	buf = {}

	# Collect data for each vulnerability
	for vul_year, vul_num, software_list, cve_ids_of_fstec_entry in vul:

		if err_event.is_set():
			break

		if not issues_checker.SetVulID(vul_year, vul_num, cve_ids_of_fstec_entry):
			err = issues_checker.Err()
			err_event.set()
			break

		# For each combination of vulnerable products of this vulnerability
		for vul_software_comb in SplitSoftware(software_list):

			buf.clear()
			vulnerability_flag = 0
			there_are_unprocessed_products = False

			# For each product of this combination
			for vul_software in vul_software_comb:

				vulnerable_by_itself = vul_software.IsVulnerable()

				if not issues_checker.SetVulSoftware(vul_software):
					there_are_unprocessed_products = True
					continue

				got_vulnerable_packages = False
				product_is_not_processed = True

				# For each package mapped to this product
				for i, (package, cve_fixes, bdu_fixes) \
						in enumerate(issues_checker.DataOfMappedPackage()):

					# SetPackage: None means an error, another falsy value
					# means that this package is skipped
					package_is_set = issues_checker.SetPackage(
						package, cve_fixes, bdu_fixes, issued_packages, i)
					if package_is_set is None:
						err = issues_checker.Err()
						err_event.set()
						failed = True
						break
					if not package_is_set:
						continue

					# Check: None means an error, another falsy value means
					# that there is no issue for this package; the error text
					# is kept so that it reaches the parent process
					fix = issues_checker.Check()
					if fix is None:
						err = issues_checker.Err()
						err_event.set()
						failed = True
						break
					if not fix:
						continue

					# Writing the data into the temporary container
					buf[issues_checker.IssueKey()] = (fix, vulnerable_by_itself)

					if fix == VULNERABLE_VER:
						got_vulnerable_packages = True

					product_is_not_processed = False

				if failed:
					break

				# Updating the flag that shows whether this combination has
				# vulnerable packages, not vulnerable ones, or both
				vulnerability_flag |= \
					VULNERABLE if got_vulnerable_packages else NOT_VULNERABLE

				# If no issue was recorded for the packages of this product
				# (e.g. they are excluded or not mapped), then this whole
				# combination can't be considered vulnerable for sure
				if product_is_not_processed:
					there_are_unprocessed_products = True

			# Don't merge the data of a combination that was aborted midway
			if failed:
				break

			# Fixes of an uncertain combination get an 'x' mark
			uncertain = there_are_unprocessed_products or \
				vulnerability_flag == VULNERABLE_AND_NOT_VULNERABLE

			for k, v in buf.items():
				fix, vulnerable_by_itself = v
				if not vulnerable_by_itself:
					continue
				package_id, package_variant, vul_ver = k
				if uncertain:
					fix = 'x' if fix == VULNERABLE_VER else fix + 'x'
				if args.noupdate:
					key = (vul_year, vul_num, package_id, package_variant, fix)
				else:
					key = (vul_year, vul_num, package_id, fix)
				data[key].add(vul_ver)

		if failed:
			break

	# Merging vulnerable versions
	data = [k + (' '.join(MergeVersions(list(v))),) for k, v in data.items()]

	# Sending data for all detected issues at once
	if data and not err_event.is_set():
		if args.noupdate:
			vul_marker = issues_checker.VulMarker()
			branch = issues_checker.TargetBranch()
			for vul_year, vul_num, package_id, package_variant, fix, _ in data:
				vul_id = f'{vul_marker}{vul_year}-{vul_num}'
				package_full_names = \
					issued_packages.get(package_id, {}).get(package_variant)
				if not package_full_names:
					err = "Can't get a name of an issued package"
					err_event.set()
					continue
				for package_full_name in package_full_names:
					print(vul_id, branch, package_full_name, fix)
		else:
			err = mediator.SendIssues(
				issues_checker.TargetBranch(),
				issues_checker.DataSource(),
				data)
			if err:
				err_event.set()

	send_pipe.send([err])

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Detect all CVE issues for a specified branch using multiple processes

def _VulEntriesOfGivenPackages(branch, target, packages):
	'''Return vulnerability entries for packages given by their full names.

	For every package of a product whose name starts with "_<branch>_", each
	vulnerability ID listed among its fixes becomes an entry
	[vul_year, vul_num, software_list, ''] with a software_list of the form
	'-:<product>:(;<package version>]:*'. The fixes are taken from the second
	element of each package tuple for NVD and from the third one otherwise
	(presumably CVE and BDU fixes respectively -- TODO confirm).
	'''

	entries = []
	branch_prefix = f'_{branch}_'
	for product, packages_data in packages.items():
		if not product.startswith(branch_prefix):
			continue
		for package_data in packages_data:
			if target == NVD_DATA_SRC:
				package, fixes, _ = package_data
			else:
				package, _, fixes = package_data
			software_list = f'-:{product}:(;{package.Ver().Value()}]:*'
			# A vulnerability ID presumably looks like "<year>-<number>"
			entries.extend(
				[vul_id[:4], vul_id[5:], software_list, ''] for vul_id in fixes)
	return entries

def DetectIssues(branch, target, packages_selection, vul, kernel_fixes,
		excluded_issues, ignored_pairs):
	'''Detect issues of one type (target) for one branch.

	branch             -- branch name (LOCAL_SYS for the local system)
	target             -- issues type / data source
	packages_selection -- packages selected in the main block (presumably an
	                      empty selection means all packages)
	vul                -- list of lists of vulnerability entries, one list per
	                      process; it is not modified
	kernel_fixes, excluded_issues, ignored_pairs -- passed to IssuesChecker

	Returns True on success and False on failure.
	'''

	# Forming a list of package descriptions
	packages = mediator.GetPackages(branch, target, packages_selection,
		not args.full_package_names)
	if not packages:
		# Finding no packages is a failure only when no selection was given
		return bool(packages_selection)

	if args.full_package_names:
		extra_vul = _VulEntriesOfGivenPackages(branch, target, packages)
		if extra_vul:
			# The entries are added to the last portion of vulnerabilities.
			# The caller's list is copied rather than modified, because the
			# same list is passed for every branch. It may also be empty
			# (an empty FSTEC list is tolerated by the main block)
			vul = (list(vul[:-1]) + [list(vul[-1]) + extra_vul]) if vul \
				else [extra_vul]

	issues_checker = IssuesChecker(
		branch,
		conf.GetLocalSysBranch(),
		target,
		packages,
		kernel_fixes,
		excluded_issues,
		ignored_pairs,
		not args.full_package_names
		)
	if not issues_checker.Ready():
		printer.Err(f'Wrong target "{target}"')
		return False

	msg = f'Detecting {target.upper()} issues for ' + \
		('the local system' if branch == LOCAL_SYS else f'{branch} branch')
	printer.LineBegin(msg)

	# Running multiple processes of issues detection
	warnings, ok = Parallel(ProcessVulnerabilities, vul, issues_checker)

	if not ok:
		printer.Err('Some process has terminated with an error')
		return False

	for warn in warnings:
		printer.LineAddExtra(warn)

	printer.Success()

	return True

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

if __name__ == '__main__':

	# Forming the selection of packages. For full package names (given
	# explicitly or taken from the local system) it maps a package name to a
	# list of (version, release) pairs; otherwise it maps a name to None
	if args.full_package_names or args.local_sys:
		if args.packages or (args.full_package_names and args.local_sys):
			printer.Err('Only one of the following arguments can be set: '
				'"--packages" or "--full_package_names" or "--local_sys"')
			exit(1)
		if args.local_sys:
			args.full_package_names, err = GetLocalSysFullPackageNames()
			if err:
				printer.Err(err)
				exit(1)
		packages = defaultdict(list)
		try:
			for package in \
					[Package(full_name=p) for p in args.full_package_names]:
				packages[package.Name()] \
					.append((package.Ver().Value(), package.Rel()))
		except Exception:
			# Not a bare 'except': KeyboardInterrupt and SystemExit must not
			# be reported as a wrong package name format
			printer.Err('Format of a full package name is '
				'"<name>-<version>-<release>"')
			exit(1)
	else:
		packages = dict.fromkeys(args.packages)

	# Forming a list of branches that will be processed ('all' is passed on as
	# an empty list, presumably meaning every configured branch)
	branches, err = \
		conf.SelectBranches(args.branches if 'all' not in args.branches else [])
	if err or not branches:
		printer.Err(err if err else 'Can\'t form a list of branch names')
		exit(1)

	# Determining the issues types
	targets, _ = conf.SelectDataSources(args.types)
	if not targets:
		printer.Err('Can\'t form a list of data sources')
		exit(1)

	# Getting a dict of issues that should not be presented
	excluded_issues = ExcludedIssues(GetDef('CONF_DIR', args.debug))

	# Detecting issues for all specified types of issues
	for target in targets:
		# Getting a list of lists of vuls (each list for a separate process);
		# an empty result is tolerated for FSTEC only
		vul = mediator.GetVulnerabilities(target, args.noupdate)
		if not vul and target != FSTEC_DATA_SRC:
			exit(1)
		# Getting linux kernel CVE fixes
		kernel_fixes = \
			mediator.GetLinuxKernelFixes() if target == NVD_DATA_SRC else {}
		# Reading a list of ignored mapping pairs
		printer.LineBegin(f'Loading {target.upper()} ignore list')
		ignored_pairs, msg = \
			IgnoredPairs(GetDef('CONF_DIR', args.debug), target).Read()
		if ignored_pairs is None:
			printer.Err(msg)
			exit(1)
		printer.Success(msg)
		# Detecting issues for all specified branches
		for branch in branches:
			# With --prepare the issues table is recreated (only if the DB is
			# going to be updated); without it the table is checked
			if args.prepare:
				if not args.noupdate and \
						not mediator.RecreateIssuesTable(branch, target):
					exit(1)
			elif not mediator.CheckIssuesTable(branch, target):
				exit(1)
			# A failure is fatal for all issues types except FSTEC
			if not DetectIssues(branch, target, packages, vul, kernel_fixes,
						excluded_issues, ignored_pairs) and \
					target != FSTEC_DATA_SRC:
				exit(1)

	exit(0)
