#!/usr/bin/python3

#	cve-manager : CVE management tool
#	Copyright (C) 2017-2025 Alexey Appolonov
#
#	This program is free software: you can redistribute it and/or modify
#	it under the terms of the GNU General Public License as published by
#	the Free Software Foundation, either version 3 of the License, or
#	(at your option) any later version.
#
#	This program is distributed in the hope that it will be useful,
#	but WITHOUT ANY WARRANTY; without even the implied warranty of
#	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#	GNU General Public License for more details.
#
#	You should have received a copy of the GNU General Public License
#	along with this program.  If not, see <http://www.gnu.org/licenses/>.

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

import argparse
from sys                        import argv
from collections                import defaultdict
from verlib.check               import IsDateVer
from verlib.cmp                 import CompareSVersions
from verlib.merge               import MergeVersions
from cve_manager.const          import DATA_SOURCES, NVD_DATA_SRC, \
	FSTEC_DATA_SRC, VUL_MARKERS, FIX_FALSE_COMPLEX_VUL, FIX_VULNERABLE_VER, \
	FIX_ALT_ERRATA_FIXED, FIX_ALT_ERRATA_VULNERABLE, FIX_ALT_ERRATA_EXCLUDED, \
	FIX_ALT_ERRATA_MISSING_VUL, LINUX_KERNEL_VULNS
from cve_manager.desc           import ISSUES
from cve_manager.common         import GetDef, NewArgParser, Init
from cve_manager.conf           import DB_CON_SEC, LOCAL_SYS
from cve_manager.parallel       import Parallel
from cve_manager.ignored_pairs  import IgnoredPairs
from cve_manager.package        import Package
from cve_manager.software       import SplitSoftware
from cve_issues.mediator        import Mediator
from cve_issues.excluded_issues import ExcludedIssues
from cve_issues.issues_checker  import IssuesChecker
from cve_issues.vul_id          import VulID
from cve_issues.helpers         import GetLocalSysFullPackageNames, \
	MergeKernelVulsWithRestVuls

# Help-text fragment shared by the options that force the no-update mode
# ("--full_package_names" and "--local_sys", see the end of this section)
NO_UPDATE = 'the results will be printed on screen and the DB won\'t be updated'

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Parsing the arguments

# The parser is assembled in stages: options declared here are interleaved
# with options added by NewArgParser(), so the order of the calls below
# defines the order of the options in the generated help
argparser = argparse.ArgumentParser(description=ISSUES)
# "--prepare" makes the main section recreate the issues tables instead of
# just checking them
argparser.add_argument(
	'--prepare',
	action='store_true',
	help='Recreate issues tables'
	)
# "--types" is restricted to the known data sources; an empty list (default)
# is passed as is to conf.SelectDataSources() in the main section
argparser.add_argument(
	'-t', '--types',
	metavar='ISSUES_TYPE', type=str, nargs='+', default=[],
	choices=DATA_SOURCES,
	help=f'Issues types ({" or ".join(DATA_SOURCES)}, all by default)'
	)
# NOTE(review): args.branches and args.packages are read later in this script
# but are not declared here - presumably they are added by this call, confirm
# in cve_manager.common.NewArgParser
argparser = NewArgParser(ptype='bp', base=argparser)
argparser.add_argument(
	'--full_package_names',
	metavar='PACKAGE_NAME_VER_REL', type=str, nargs='+', default=[],
	help=f'Full package names (<name>-<version>-<release>), {NO_UPDATE}'
	)
argparser.add_argument(
	'--local_sys',
	action='store_true',
	help=f'Detect issues for all packages of the local system, {NO_UPDATE}'
	)
argparser.add_argument(
	'--noupdate',
	action='store_true',
	help='Do not update the issues tables, display the results on screen'
	)
# NOTE(review): args.debug is read in the main section but is not declared
# here - presumably it is added by this call, confirm in NewArgParser
argparser = NewArgParser(ptype='m', base=argparser)
args = argparser.parse_args()

# Running the script without any command line argument prints the help and
# exits with an error code (the check is done after the arguments are parsed)
if len(argv) < 2:
	argparser.print_help()
	exit(1)

# Explicitly specified full package names and the local system mode always
# imply the no-update mode
if args.full_package_names or args.local_sys:
	args.noupdate = True

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Initialising helper objects and reading the configuration file

# Init() returns the printer and the configuration objects that are used as
# module-level globals by the functions below
printer, conf = Init(args)
if args.noupdate:
	# In the no-update mode the detected issues are printed with print();
	# NOTE(review): AlertsOnly/Plain presumably reduce the printer output to
	# undecorated alerts so it doesn't mix with the results - confirm in the
	# printer implementation
	printer.AlertsOnly(True)
	printer.Plain(True)
# conf.Get() is called with a one-element list and its result is unpacked as
# a one-element sequence
mysql_config, = conf.Get([DB_CON_SEC])
# Module-level DB mediator shared by the worker functions and the main section
mediator = Mediator(mysql_config, printer)

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Handle a given portion of vulnerabilities (as a separate process)

# Bit flags accumulated over the products of one combination of vulnerable
# software: a product sets VULNERABLE if at least one of its packages got the
# FIX_VULNERABLE_VER verdict, otherwise it sets NOT_VULNERABLE
VULNERABLE = 0b01
NOT_VULNERABLE = 0b10
VULNERABLE_AND_NOT_VULNERABLE = 0b11

def ProcessVulnerabilities(vul, issues_checker, send_pipe, err_event):
	"""Detect issues for a portion of vulnerabilities and report them.

	Reads the module-level "args" (only "args.noupdate") and "mediator".

	Parameters:
	vul            -- iterable of (vul_year, vul_num, software_list,
	                  cve_ids_of_fstec_entry) entries;
	issues_checker -- checker object that is fed with the current
	                  vulnerability, product and package and returns the
	                  verdict ("fix") for that package;
	send_pipe      -- object with a send() method; it receives a one-element
	                  list holding the error message (an empty string if
	                  there was no error);
	err_event      -- event-like object (is_set()/set()) used as an error
	                  flag; it is checked before each vulnerability and set
	                  when this function fails.

	In the no-update mode the detected issues are printed on screen,
	otherwise they are passed to mediator.SendIssues().
	"""

	err = ''
	data = defaultdict(set)
	# {package_id: {package_variant: {full package names}}}; the container is
	# handed to issues_checker.SetPackage() and only read in this function
	# (NOTE(review): presumably SetPackage() fills it - confirm)
	issued_packages = defaultdict(lambda: defaultdict(set))
	# {issue key: (fix, product is vulnerable by itself)} for the combination
	# that is currently processed
	buf = {}

	# Collect data for each vulnerability
	for vul_year, vul_num, software_list, cve_ids_of_fstec_entry in vul:

		# Stop as soon as the error flag is raised (by this function or by
		# any other holder of the same event)
		if err_event.is_set():
			break

		issues_checker.SetVulID(vul_year, vul_num, cve_ids_of_fstec_entry)

		# For each combination of vulnerable products of this vulnerability
		for vul_software_comb in SplitSoftware(software_list):

			buf.clear()
			vulnerability_flag = 0
			there_are_unprocessed_products = False

			# For each product of this combination
			for vul_software in vul_software_comb:

				if not issues_checker.SetVulSoftware(vul_software):
					there_are_unprocessed_products = True
					continue

				got_vulnerable_packages = False
				product_is_not_processed = True

				# For each package mapped to this product
				for i, data_of_mapped_package \
						in enumerate(issues_checker.DataOfMappedPackages()):

					package, cve_fixes, bdu_fixes = data_of_mapped_package

					if not issues_checker.SetPackage(
							package, cve_fixes, bdu_fixes, issued_packages, i):
						continue

					fix, err = issues_checker.Check()
					if err:
						err_event.set()
						break

					if not fix:
						continue

					# Writing the data into temporary container
					buf[issues_checker.IssueKey()] = \
						(fix, vul_software.IsVulnerable())

					if fix == FIX_VULNERABLE_VER:
						got_vulnerable_packages = True

					product_is_not_processed = False

				if err:
					break

				# Updating the flag that shows whether some packages of this
				# product are vulnerable or none of them is vulnerable
				vulnerability_flag |= \
					VULNERABLE if got_vulnerable_packages else NOT_VULNERABLE

				# If some product of this combination is excluded or not mapped
				# then this whole combination can't be considered vulnerable
				if product_is_not_processed:
					there_are_unprocessed_products = True

			# The check has failed: leave the combinations loop as well,
			# otherwise the next successful Check() call would overwrite the
			# error message with an empty string (the collected data is not
			# sent anyway once the error flag is set)
			if err:
				break

			for k, v in buf.items():
				fix, vulnerable_by_itself = v
				if not vulnerable_by_itself:
					continue
				package_id, package_variant, vul_ver = k
				# A combination with unprocessed products or with a mix of
				# vulnerable and not vulnerable products is marked as a false
				# complex vulnerability
				if there_are_unprocessed_products or \
						vulnerability_flag == VULNERABLE_AND_NOT_VULNERABLE:
					fix = FIX_FALSE_COMPLEX_VUL if fix == FIX_VULNERABLE_VER else \
						fix + FIX_FALSE_COMPLEX_VUL
				# The package variant is kept in the key only in the no-update
				# mode, where it is needed to look up the full package names
				if args.noupdate:
					data[(vul_year, vul_num, package_id, package_variant, fix)] \
						.add(vul_ver)
				else:
					data[(vul_year, vul_num, package_id, fix)] \
						.add(vul_ver)

	# Merging vulnerable versions (the merged string becomes the last element
	# of each record)
	data = [k + (' '.join(MergeVersions(list(v))),) for k, v in data.items()]

	# Sending data of all detected issues at once
	if data and not err_event.is_set():
		if args.noupdate:
			for vul_year, vul_num, package_id, package_variant, fix, _ in data:
				vul_id = f'{issues_checker.VulMarker()}{vul_year}-{vul_num}'
				branch = issues_checker.TargetBranch()
				package_full_names = \
					issued_packages.get(package_id, {}).get(package_variant)
				if package_full_names:
					for package_full_name in package_full_names:
						print(vul_id, branch, package_full_name, fix)
				else:
					err = "Can't get a name of an issued package"
					err_event.set()
		else:
			err = mediator.SendIssues(
				issues_checker.TargetBranch(),
				issues_checker.DataSource(),
				data)
			if err:
				err_event.set()

	send_pipe.send([err])

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Handle a given portion of ALT errata (as a separate process)
# alt_errata is {package_name: {package_ver: {package_rel: {vul_ids}}}},
# packages_nvr is {package_name: [(id, ver, ver_is_date, rel), ..]}

def ProcessAltErrata(alt_errata, data_needed_for_alt_errata_processing,
		send_pipe, err_event):
	"""Detect issues for a portion of ALT errata and report them.

	Reads the module-level "args" and "mediator".

	Parameters:
	alt_errata -- {package_name: {package_ver: {package_rel: vul_ids}}},
	              where vul_ids is an iterable of (vul_year, vul_num);
	data_needed_for_alt_errata_processing -- sequence of five items:
	              branch, data source, {package_name: [(id, ver,
	              ver_is_date, rel), ..]}, {vul_year: {vul_num, ..}} of the
	              available vulnerabilities, excluded issues object;
	send_pipe  -- object with a send() method; it receives the list of
	              warnings collected during the processing;
	err_event  -- event-like object (is_set()/set()); it is set when
	              mediator.SendIssues() returns an error, and nothing is
	              sent or printed if it is already set.

	In the no-update mode the detected issues are printed on screen,
	otherwise they are passed to mediator.SendIssues().
	"""

	branch, data_source, packages_nvr, available_vul_ids, excluded_issues = \
		data_needed_for_alt_errata_processing
	vul_marker = VUL_MARKERS.get(data_source, '???-')
	data = defaultdict(set)
	warn = []
	# Missing packages are reported only when the whole branch is processed,
	# i.e. when no selection of packages was requested
	is_full_branch = not (args.packages or args.full_package_names or
		args.local_sys)

	def Details(name, ver_a, ver_b):
		return f'{name}-{ver_a} <> {name}-{ver_b}'
	# - nested func

	for package_name, errata_package_versions in alt_errata.items():
		package_info = packages_nvr.get(package_name)
		if not package_info:
			if is_full_branch:
				warn.append(f'There is no "{package_name}" package '
					'in this branch')
			continue
		for package_id, package_ver, package_ver_is_date, package_rel in \
				package_info:
			for errata_package_ver, errata_package_releases \
					in errata_package_versions.items():
				# Date versions are comparable only with date versions
				if IsDateVer(errata_package_ver) != package_ver_is_date:
					w = 'Comparison of date version and non-date version: ' + \
						Details(package_name, package_ver, errata_package_ver)
					warn.append(w)
					continue
				ver_cmp_res, ver_cmp_symbolic = \
					CompareSVersions(package_ver, errata_package_ver)
				# This pair of results is treated as a comparison of a symbolic
				# version with a non-symbolic one, such errata is skipped
				if abs(ver_cmp_res) == 1 and 0 < ver_cmp_symbolic < 3:
					w = 'Comparison of symbolic and non-symbolic versions: ' + \
						Details(package_name, package_ver, errata_package_ver)
					warn.append(w)
					continue
				for errata_package_rel, vul_ids in errata_package_releases.items():
					# The package is vulnerable if its version is lower than the
					# errata version; for equal versions the releases decide
					# (if the errata release is specified)
					package_is_vulnerable = ver_cmp_res < 0
					if ver_cmp_res == 0 and errata_package_rel:
						rel_cmp_res, _ = \
							CompareSVersions(package_rel, errata_package_rel)
						package_is_vulnerable = rel_cmp_res < 0
					fix0 = FIX_ALT_ERRATA_VULNERABLE if package_is_vulnerable \
						else FIX_ALT_ERRATA_FIXED
					for vul_year, vul_num in vul_ids:
						fix = fix0
						vul_id = f'{vul_marker}{vul_year}-{vul_num}'
						if vul_num not in available_vul_ids.get(vul_year, set()):
							fix += FIX_ALT_ERRATA_MISSING_VUL
						else:
							# Checking that this issue is excluded
							is_excluded_issue, err = excluded_issues.IsExcluded(
								branch, vul_id, package_name, package_ver, package_rel)
							# None (as opposed to False) means that the check
							# itself has failed
							if is_excluded_issue is None:
								w = f'Can\'t check whether {vul_id} is excluded for ' \
									f'package {package_name}-{package_ver}-{package_rel}'
								if err:
									w += f': {err}'
								warn.append(w)
								continue
							if is_excluded_issue:
								fix += FIX_ALT_ERRATA_EXCLUDED
						# Full package name is used for printing on screen,
						# package ID is used for updating the DB
						package_specifier = \
							f'{package_name}-{package_ver}-{package_rel}' \
							if args.noupdate else package_id
						# Symbols of the combined fix are sorted so that the same
						# combination always results in the same key
						data[(vul_year, vul_num, package_specifier, ''.join(sorted(fix)))] \
							.add(f'(;{errata_package_ver})')

	# Merging vulnerable versions (the merged string becomes the last element
	# of each record)
	data = [k + (' '.join(MergeVersions(list(v))),) for k, v in data.items()]

	# Sending data for all detected issues at once
	if data and not err_event.is_set():
		if args.noupdate:
			for vul_year, vul_num, package_full_name, fix, _ in data:
				vul_id = f'{vul_marker}{vul_year}-{vul_num}'
				print(vul_id, branch, package_full_name, fix)
		else:
			err = mediator.SendIssues(branch, data_source, data, ignore=True)
			if err:
				err_event.set()

	send_pipe.send(warn)

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Convert {map_name: [(Package_0, _, _), .. , (Package_n, _, _)]} to
# {name: [(id, ver_0, ver_0_is_date, rel_0), .. ,
#     (id, ver_n, ver_n_is_date, rel_n)]}

def GetPackagesNameVerRel(packages):
	"""Regroup the descriptions of mapped packages by package name.

	"packages" is a dict whose values are lists of 3-item entries with a
	package object as the first item (the dict keys and the other two items
	are not used). The result is a defaultdict(list) that maps a package name
	to the (id, version, version_is_date, release) records of all packages
	with this name, in the order they are met.
	"""

	records_by_name = defaultdict(list)

	for mapped_entries in packages.values():
		for pkg, _, _ in mapped_entries:
			name = pkg.Name()
			record = (
				pkg.ID(),
				pkg.Ver().Value(),
				pkg.Ver().IsDate(),
				pkg.Rel(),
				)
			records_by_name[name].append(record)

	return records_by_name

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Detect all CVE issues for a specified branch using multiple processes

def DetectIssues(branch, target, packages_selection, vul, excluded_issues,
		ignored_pairs):
	"""Detect issues of one type for one branch.

	Uses the module-level "args", "conf", "mediator" and "printer".

	Parameters:
	branch             -- name of the branch (or LOCAL_SYS);
	target             -- data source (type of issues);
	packages_selection -- selection of packages that is passed to
	                      mediator.GetPackages(), a falsy value means that
	                      there is no selection;
	vul                -- list of portions (lists) of vulnerabilities, each
	                      vulnerability is (year, num, software_list,
	                      cve_ids_of_fstec_entry); the passed object is not
	                      modified;
	excluded_issues    -- excluded issues object;
	ignored_pairs      -- ignored mapping pairs.

	Returns True on success and False on failure. If there are no packages
	to process, the result is True only when a selection of packages was
	requested.
	"""

	# Getting ALT errata
	alt_errata = mediator.GetAltErrata(branch, target)
	if alt_errata is None:
		return False

	# Getting a list of package descriptions
	packages = mediator.GetPackages(branch, target, packages_selection,
		not args.full_package_names)
	if not packages:
		return False if packages is None else bool(packages_selection)

	# Getting name, ver and release of packages (the result is always a dict,
	# so an empty one is handled like an empty list of packages)
	packages_nvr = GetPackagesNameVerRel(packages)
	if not packages_nvr:
		return bool(packages_selection)

	# Detect issues for packages that are not mapped
	if args.full_package_names:
		# The entries are added to a copy of the portions: the caller reuses
		# the same list of vulnerabilities for every branch, and entries that
		# are specific to this branch must not leak into the next one
		vul = [list(portion) for portion in vul]
		count = 0
		for package_name, packages_data in packages.items():
			if package_name.startswith(f'_{branch}_'):
				for package_data in packages_data:
					if target == NVD_DATA_SRC:
						package, fixes, _ = package_data
					else:
						package, _, fixes = package_data
					software_list = f'-:{package_name}:(;{package.Ver().Value()}]:*'
					for vul_id in fixes:
						vul_year = vul_id[:4]
						vul_num  = vul_id[5:]
						# New entries are distributed over the portions of
						# vulnerabilities in a round-robin manner
						vul[count].append([vul_year, vul_num, software_list, ''])
						count += 1
						if count >= len(vul):
							count = 0

	issues_checker = IssuesChecker(
		branch,
		conf.GetLocalSysBranch(),
		target,
		packages,
		excluded_issues,
		ignored_pairs,
		not args.full_package_names
		)
	if not issues_checker.Ready():
		printer.Err(f'Wrong target "{target}"')
		return False

	# {vul_year: {vul_num, ..}} of all vulnerabilities that will be processed,
	# it is used to mark ALT errata that refer to missing vulnerabilities
	available_vul_ids = defaultdict(set)
	for v in vul:
		for v_year, v_num, _, _ in v:
			available_vul_ids[v_year].add(v_num)

	msg = f'Detecting {target.upper()} issues for ' + \
		('the local system' if branch == LOCAL_SYS else f'{branch} branch')
	printer.LineBegin(msg)

	# First, process the list of vulnerabilities, and then process ALT errata
	data_needed_for_alt_errata_processing = [branch, target, packages_nvr,
		available_vul_ids, excluded_issues]
	stages = [
		(ProcessVulnerabilities,
			vul, issues_checker),
		(ProcessAltErrata,
			alt_errata, data_needed_for_alt_errata_processing),
		]

	for stage in stages:
		# Running multiple processes of issues detection
		warnings, ok = Parallel(*stage)
		if not ok:
			printer.Err('Some process has terminated with an error')
			return False
		# Repeated warnings are shown once, in sorted order
		for warn in warnings if len(warnings) < 2 else sorted(set(warnings)):
			printer.LineAddExtra(warn)

	printer.Success()

	return True

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

if __name__ == '__main__':

	# Forming the selection of packages: {name: [(ver, rel), ..]} for full
	# package names (given explicitly or taken from the local system),
	# {name: None} for "--packages"
	if args.full_package_names or args.local_sys:
		if args.packages or (args.full_package_names and args.local_sys):
			printer.Err('Only one of the following arguments can be set: '
				'"--packages" or "--full_package_names" or "--local_sys"')
			exit(1)
		if args.local_sys:
			args.full_package_names, err = GetLocalSysFullPackageNames()
			if err:
				printer.Err(err)
				exit(1)
		packages = defaultdict(list)
		try:
			for package in \
					[Package(full_name=p) for p in args.full_package_names]:
				packages[package.Name()] \
					.append((package.Ver().Value(), package.Rel()))
		# Any failure to parse a name is reported as a format error, but
		# KeyboardInterrupt and SystemExit must not be swallowed here
		except Exception:
			printer.Err('Format of a full package name is '
				'"<name>-<version>-<release>"')
			exit(1)
	else:
		packages = {name : None for name in args.packages}

	# Forming a list of branches that will be processed
	branches, err = \
		conf.SelectBranches(args.branches if 'all' not in args.branches else [])
	if err or not branches:
		printer.Err(err if err else 'Can\'t form a list of branch names')
		exit(1)

	# Determining the issues types
	targets, _ = conf.SelectDataSources(args.types)
	if not targets:
		printer.Err('Can\'t form a list of data sources')
		exit(1)

	# Getting a dict of issues that should not be presented
	excluded_issues = ExcludedIssues(GetDef('CONF_DIR', args.debug))

	# Detecting issues for all specified types of issues
	for target in targets:
		# Getting a list of lists of vuls (each list for a separate process)
		if target == NVD_DATA_SRC:
			vul_kernel = mediator.GetVulnerabilities(LINUX_KERNEL_VULNS)
			vul_rest = mediator.GetVulnerabilities(target)
			vul = MergeKernelVulsWithRestVuls(vul_kernel, vul_rest)
		else:
			vul = mediator.GetVulnerabilities(target)
		if not vul:
			exit(1)
		# Reading a list of ignored mapping pairs
		printer.LineBegin(f'Loading {target.upper()} ignore list')
		ignored_pairs, msg = \
			IgnoredPairs(GetDef('CONF_DIR', args.debug), target).Read()
		if ignored_pairs is None:
			printer.Err(msg)
			exit(1)
		printer.Success(msg)
		# Detecting issues for all specified branches
		for branch in branches:
			# With "--prepare" the issues table is recreated (unless the DB is
			# not going to be updated), otherwise it has to pass the check
			if args.prepare:
				if not args.noupdate and \
						not mediator.RecreateIssuesTable(branch, target):
					exit(1)
			elif not mediator.CheckIssuesTable(branch, target):
				exit(1)
			# A failed detection is fatal for every data source except FSTEC
			if not DetectIssues(branch, target, packages, vul, excluded_issues,
						ignored_pairs) and \
					target != FSTEC_DATA_SRC:
				exit(1)

	exit(0)
