#!/usr/bin/python3

#	cve-manager : CVE management tool
#	Copyright (C) 2017-2025 Alexey Appolonov
#
#	This program is free software: you can redistribute it and/or modify
#	it under the terms of the GNU General Public License as published by
#	the Free Software Foundation, either version 3 of the License, or
#	(at your option) any later version.
#
#	This program is distributed in the hope that it will be useful,
#	but WITHOUT ANY WARRANTY; without even the implied warranty of
#	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#	GNU General Public License for more details.
#
#	You should have received a copy of the GNU General Public License
#	along with this program.  If not, see <http://www.gnu.org/licenses/>.

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

import argparse
import subprocess
from os.path             import join as path_join
from ax.filesystem       import PrepareDir
from cve_backup.helpers  import MySQLDump
from cve_manager.common  import Init, NewArgParser
from cve_manager.conf    import DB_CON_SEC, COMMON_SEC
from cve_manager.const   import DATA_SOURCES, NVD_DATA_SRC, FSTEC_DATA_SRC, \
	LINUX_KERNEL_VULNS, VUL_MARKERS_SHORT
from cve_manager.db      import DB
from cve_manager.desc    import ISSUES_PREP

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Parsing the arguments

# Script-specific option: an optional list of issue types, each restricted to
# the values of DATA_SOURCES; when the option is omitted the list is empty
argparser = argparse.ArgumentParser(description=ISSUES_PREP)
argparser.add_argument(
	'-t', '--types',
	nargs='+',
	type=str,
	choices=DATA_SOURCES,
	default=[],
	metavar='ISSUES_TYPE',
	help='Issues types ({}, all by default)'.format(' or '.join(DATA_SOURCES)),
	)

# The parser built above is handed to NewArgParser as a base; the parser it
# returns is the one that actually parses the command line (the main section
# of this script reads `args.branches` from the result)
argparser = NewArgParser(base=argparser, ptype='bm')
args = argparser.parse_args()

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Creating the printer and the configuration object, then fetching the DB
# connection section and the common section of the configuration

printer, conf = Init(args)
mysql_config, common_params = conf.Get([DB_CON_SEC, COMMON_SEC])

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Get a set of names of all tables of the CVE database

def DatabaseTableNames():

	# Returns a set with the names of all tables whose `table_schema` equals
	# the database name found in `mysql_config`, or None if the connection
	# can't be established or the select doesn't give a usable cursor

	# Same connection parameters as for the CVE database, except that the
	# 'database' entry (when present) points to 'information_schema'
	inf_schema_config = {
		key: ('information_schema' if key == 'database' else val)
		for key, val in mysql_config.items()
		}

	connection = DB(printer)
	if not connection.Connect(inf_schema_config):
		return None

	schema = mysql_config.get('database', '')
	rows = connection.Select(
		['table_name'], 'tables', f"table_schema = '{schema}'"
		)
	if not rows:
		return None

	# Each row is unpacked as a one-element tuple holding the table name
	return {name for name, in rows}

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Dump a short version of the CVE database (all tables needed to detect issues
# of the specified types for a specified branch)

def Dump(types_of_issues, branch, db_tables, cve_manager_home):

	# types_of_issues  - names of the issue types to keep data for
	# branch           - branch name, used as a prefix of per-branch tables
	#                    and of the resulting file name
	# db_tables        - set of names of all tables of the CVE database
	# cve_manager_home - directory where the dump file is written
	#
	# Returns True if the dump was written, False if MySQLDump reported an
	# error (the error is printed)

	joined_types = '_'.join(types_of_issues)
	dump_path = path_join(cve_manager_home, f'{branch}_{joined_types}.sql')

	printer.LineBegin('Preparing a short version of the CVE DB and writing it '
		f'to "{dump_path}"')

	# Tables that must stay in the dump: the kernel vulnerabilities table,
	# the source and errata tables of the branch, plus an import table and
	# a per-branch fixes table for every requested issue type
	needed_tables = {
		LINUX_KERNEL_VULNS,
		f'{branch}_src',
		f'{branch}_alt_errata',
		}
	for issue_type in types_of_issues:
		marker = VUL_MARKERS_SHORT.get(issue_type, "")
		needed_tables.add(f'{issue_type}_vul_import')
		needed_tables.add(f'{branch}_{marker}_fixes')

	# Everything else that exists in the database is excluded from the dump
	ignored_tables = db_tables - needed_tables

	failure = MySQLDump(dump_path, mysql_config, ignored_tables=ignored_tables)
	if not failure:
		printer.Success()
		return True

	printer.Err(failure)
	return False

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

if __name__ == '__main__':

	# The script terminates by raising SystemExit directly instead of calling
	# exit(): the latter is a helper injected by the `site` module for
	# interactive use and is not guaranteed to exist (e.g. under `python -S`),
	# while SystemExit is always available and gives the same exit status.

	# Forming a list of branches that will be processed (an empty selection is
	# passed to SelectBranches when 'all' is among the requested branches)
	branches, err = \
		conf.SelectBranches(args.branches if 'all' not in args.branches else [])
	if err or not branches:
		printer.Err(err if err else 'Can\'t form a list of branch names')
		raise SystemExit(1)

	# Determining the types of issues (the second value returned by
	# SelectDataSources is not used here)
	types_of_issues, _ = conf.SelectDataSources(args.types)
	if not types_of_issues:
		printer.Err('Can\'t form a list of data sources')
		raise SystemExit(1)

	# Getting a set of names of all tables of the CVE database; both None
	# (failure) and an empty set are treated as an error
	db_tables = DatabaseTableNames()
	if not db_tables:
		printer.Err('Can\'t get names of tables of the CVE database')
		raise SystemExit(1)

	# Preparing the home dir (taken from the 'download' parameter of the
	# common configuration section)
	cve_manager_home, err = PrepareDir(common_params['download'])
	if not cve_manager_home:
		printer.Err(err)
		raise SystemExit(1)

	# Exporting data for detecting issues for all specified branches,
	# stopping at the first branch that fails to be dumped
	for branch in branches:
		if not Dump(types_of_issues, branch, db_tables, cve_manager_home):
			raise SystemExit(1)

	raise SystemExit(0)
