#!/usr/bin/python3

#	cve-manager : CVE management tool
#	Copyright (C) 2017-2022 Alexey Appolonov
#
#	This program is free software: you can redistribute it and/or modify
#	it under the terms of the GNU General Public License as published by
#	the Free Software Foundation, either version 3 of the License, or
#	(at your option) any later version.
#
#	This program is distributed in the hope that it will be useful,
#	but WITHOUT ANY WARRANTY; without even the implied warranty of
#	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#	GNU General Public License for more details.
#
#	You should have received a copy of the GNU General Public License
#	along with this program.  If not, see <http://www.gnu.org/licenses/>.

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

import argparse
from ax.ver              import IsDateVer
from cve_manager.common  import NewArgParser, Init
from cve_manager.conf    import DB_CON_SEC
from cve_manager.db      import DB
from cve_manager.defines import DATA_SOURCES, VUL_MARKERS, NVD_DATA_SRC
from cve_issues.fix      import GetFixRecord
#from cve_issues.helpers  import ConvertFSTECVer

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Initializing global objects, parsing the arguments and reading the conf file

# Options specific to this tool are declared on a plain argparse parser first;
# that parser is then handed to NewArgParser as its base
argparser = argparse.ArgumentParser(
	description=(
		'Search for vulnerabilities of a package, that have been fixed '
		'in a given range of versions (prev_ver, cur_ver]'
	),
)

# Optional filter: one or more data sources, an empty list when not given
argparser.add_argument(
	'-t',
	'--vul_type',
	choices=DATA_SOURCES,
	default=[],
	nargs='+',
	type=str,
	metavar='VUL_TYPE',
	help='Type of vulnerabilities to be considered ({}, all by default)'.format(
		' or '.join(DATA_SOURCES)
	),
)

# Mandatory options: what to look at (branch, package) ...
argparser.add_argument(
	'-b',
	'--branch',
	required=True,
	type=str,
	metavar='BRANCH_NAME',
	help='Name of a repository to be processed',
)
argparser.add_argument(
	'-p',
	'--package',
	required=True,
	type=str,
	metavar='PACKAGE_NAME',
	help='Name of a package',
)

# ... and the version range (prev_ver, cur_ver] to be examined
argparser.add_argument(
	'--prev_ver',
	required=True,
	type=str,
	metavar='PREV_VER',
	help='Version of a package before the update',
)
argparser.add_argument(
	'--cur_ver',
	required=True,
	type=str,
	metavar='CUR_VER',
	help='Current version of a package',
)

# NOTE(review): NewArgParser presumably adds the options shared by the
# cve-manager tools on top of the base parser - confirm in cve_manager.common
argparser = NewArgParser(base=argparser, ptype='mm')
args = argparser.parse_args()

printer, conf = Init(args, monitor=True)

# conf.Get is given a one-element list of sections and its result is unpacked
# into exactly one name
[mysql_config] = conf.Get([DB_CON_SEC])

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Query all issues of specified type (NVD, FSTEC), branch and package

def QueryIssues(vul_type, branch, package):
	"""
	Query the issues recorded for a package of a branch.

	Each row of the result has two columns: the vulnerability ID (the marker
	of the data source followed by '<vul_year>-<vul_num>') and the 'vul_ver'
	value of the issue.

	vul_type -- name of a data source, it must be a key of VUL_MARKERS
	branch   -- name of a branch, the prefix of the names of both tables
	package  -- name of a package, matched against the 'name' column of
	            the '<branch>_src' table

	Returns the result of DB.Cursor for the query, or None if the package
	name can't be safely embedded in the query, if the type of
	vulnerabilities is unknown or if the database connection fails.
	"""

	# The package name comes from the command line and is placed inside a
	# single-quoted SQL literal; a quote or a backslash could terminate or
	# escape that literal, so such a name is refused instead of being sent
	# to the server
	if "'" in package or '\\' in package:
		return None

	vul_marker = VUL_MARKERS.get(vul_type)
	if not vul_marker:
		return None

	table_a = f'{branch}_{vul_type}_issues'
	table_b = f'{branch}_src'
	query = \
		f"SELECT CONCAT('{vul_marker}', a.vul_year, '-', a.vul_num), a.vul_ver " \
		f"FROM {table_a} a INNER JOIN {table_b} b ON a.package_id = b.id " \
		f"WHERE b.name = '{package}'"

	db = DB(printer)
	if not db.Connect(mysql_config):
		return None

	# NOTE(review): the connection is not closed here, presumably because
	# the returned cursor still needs it - confirm against cve_manager.db
	return db.Cursor(query, tables=[table_a, table_b])

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

if __name__ == '__main__':

	# Local import: sys is needed only to terminate this script body.
	# sys.exit is used instead of the exit() helper, which is installed by
	# the site module and is not guaranteed to be available
	import sys

	# IDs of the vulnerabilities considered to affect prev_ver, but not cur_ver
	fixed_vulnerabilities = []
	vul_types, err = conf.SelectDataSources(args.vul_type)
	if err:
		printer.Err(err)
		sys.exit(1)

	# The kind (date or not) of both given versions doesn't depend on the issue
	# being examined, so it is determined once instead of once per issue
	prev_ver_is_date = IsDateVer(args.prev_ver)
	cur_ver_is_date = IsDateVer(args.cur_ver)

	for vul_type in vul_types:

		issues = QueryIssues(vul_type, args.branch, args.package)
		if issues is None:
			printer.Err(f'Can\'t get {args.branch} {vul_type.upper()} issues')
			sys.exit(1)

		# Referenced only by the disabled FSTEC version conversion below
		is_nvd_data_source = (vul_type == NVD_DATA_SRC)

		for vul_id, vul_ver in issues:

			#if not is_nvd_data_source:
			#	vul_ver = ConvertFSTECVer(vul_ver)

			vul_ver_is_date = IsDateVer(vul_ver)
			# '(;)' and '-' are the two values treated as "any version"
			vul_ver_is_any_ver = not vul_ver_is_date and vul_ver in ('(;)', '-')

			# A given version is checked against the vulnerable version only
			# if the latter is "any version" or both are of the same kind.
			# NOTE(review): when the kinds differ, the check of prev_ver is
			# skipped and the issue is still treated as affecting prev_ver -
			# confirm that this is intended
			if vul_ver_is_any_ver or vul_ver_is_date == prev_ver_is_date:
				# Presumably 'V' in a fix record marks a vulnerable version
				if 'V' not in GetFixRecord(args.prev_ver, vul_ver):
					continue

			# Previous package version is vulnerable
			if vul_ver_is_any_ver or vul_ver_is_date == cur_ver_is_date:
				if 'V' not in GetFixRecord(args.cur_ver, vul_ver):
					fixed_vulnerabilities.append(vul_id)

	for vul_id in fixed_vulnerabilities:
		print(vul_id)

	sys.exit(0)
