#!/usr/bin/python3

#	cve-manager : CVE management tool
#	Copyright (C) 2017-2025 Alexey Appolonov
#
#	This program is free software: you can redistribute it and/or modify
#	it under the terms of the GNU General Public License as published by
#	the Free Software Foundation, either version 3 of the License, or
#	(at your option) any later version.
#
#	This program is distributed in the hope that it will be useful,
#	but WITHOUT ANY WARRANTY; without even the implied warranty of
#	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#	GNU General Public License for more details.
#
#	You should have received a copy of the GNU General Public License
#	along with this program.  If not, see <http://www.gnu.org/licenses/>.

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

import argparse
import json
import os
import sys
from collections                import defaultdict
from cve_manager.common         import GetDef, NewArgParser, Init
from cve_manager.conf           import DB_CON_SEC
from cve_manager.const          import DATA_SOURCES, NVD_DATA_SRC, \
	FSTEC_DATA_SRC, VUL_MARKERS
from cve_manager.db             import DB
from cve_manager.software       import SplitSoftware
from cve_issues.excluded_issues import ExcludedIssues

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Constants

# Names of the columns of the <vul_type>_vul_import tables, in the order in
# which they are selected by GetVulnerabilities; the position of a name in this
# tuple is the index of the corresponding field in a fetched row
_VUL_COL_NAMES = (
	'vul_year',
	'vul_num',
	'software_list',
	'published',
	'cvss_score_v2',
	'cvss_vector_v2',
	'cvss_severity_v2',
	'cvss_score_v3',
	'cvss_vector_v3',
	'cvss_severity_v3',
	'cwe_id',
	'refs',
	'patch',
	'summary',
	'cve_ids',
	)

# Row field indices (the unpacking fails loudly if the count of names above and
# the count of indices here ever diverge)
(
	VUL_YEAR,
	VUL_NUM,
	SOFTWARE_LIST,
	PUBLISHED,
	CVSS_SCORE_V2,
	CVSS_VECTOR_V2,
	CVSS_SEVERITY_V2,
	CVSS_SCORE_V3,
	CVSS_VECTOR_V3,
	CVSS_SEVERITY_V3,
	CWE_ID,
	REFS,
	PATCH,
	SUMMARY,
	CVE_IDS,
) = range(len(_VUL_COL_NAMES))

# Row field index -> column name
VUL_COLS = dict(enumerate(_VUL_COL_NAMES))

ERR_CONN = 'Can\'t connect to the database'

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Initializing global objects and parsing the arguments

# Options specific to this module; NewArgParser extends this parser with the
# common cve-manager options (NOTE(review): presumably including --branches and
# --debug, which the __main__ section reads — confirm in cve_manager.common)
argparser = argparse.ArgumentParser(
	description='Special module for working with trivy'
	)
argparser.add_argument(
	'-t', '--vul_types',
	metavar='VULNERABILITY_TYPES', type=str, nargs='+',
	choices=DATA_SOURCES,
	help=f'Monitor {" or ".join(DATA_SOURCES)} issues (all by default)'
	)
argparser.add_argument(
	'-o', '--out_dir',
	metavar='DIR_PATH', type=str, required=True,
	help='Dir path in which output files will be placed'
	)
argparser = NewArgParser(ptype='bm', base=argparser)
args = argparser.parse_args()

# Printer for diagnostics and the parsed configuration; only the database
# connection section of the configuration is needed by this module
printer, conf = Init(args, monitor=False)
mysql_config, = conf.Get([DB_CON_SEC])

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Get a list of vulnerabilities of the specified type from the database

def GetVulnerabilities(vul_type):
	"""Fetch all rows of the `<vul_type>_vul_import` table.

	Columns are selected in the order of the VUL_COLS indices, so a fetched row
	can be indexed with VUL_YEAR, SOFTWARE_LIST, SUMMARY, etc. For FSTEC the
	cvss_severity_v2 column is replaced with an empty-string literal and the
	`severity` column is selected in place of cvss_severity_v3; for NVD the
	cve_ids column is replaced with an empty-string literal.

	Returns a tuple (rows, '') on success, or (None, error_message) when the
	database is unreachable, the selection fails or the table is empty.
	"""

	db = DB()
	if not db.Connect(mysql_config):
		return None, ERR_CONN

	# Getting the list
	table = f'{vul_type}_vul_import'

	cols = []
	for i, col in sorted(VUL_COLS.items()):
		if (vul_type == FSTEC_DATA_SRC and i == CVSS_SEVERITY_V2) or \
				(vul_type == NVD_DATA_SRC and i == CVE_IDS):
			# Presumably absent from this source's table: keep the row layout
			# by selecting an empty literal instead
			cols.append("''")
		elif vul_type == FSTEC_DATA_SRC and i == CVSS_SEVERITY_V3:
			cols.append('severity')
		else:
			cols.append(col)

	cursor = db.Select(cols, table)
	if not cursor:
		# A None cursor means the query itself failed; any other falsy value
		# is treated as an empty result
		if cursor is None:
			return None, f'Can\'t select {vul_type.upper()} data'
		return None, f'Empty list of {vul_type.upper()} vulnerabilities'

	return cursor.fetchall(), ''

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Get a map of product names to package names

def PackagesMap(vul_type, branch):
	"""Map `vul_type` product names to package names of `branch`.

	Reads the `name` and `<vul_type>_name` columns of the `<branch>_src` table,
	skipping rows whose `<vul_type>_name` is empty. That column may hold several
	whitespace-separated product names, and each of them is mapped to the
	package (`name`) of the same row.

	Returns:
	    (map, '') on success, where map is a defaultdict(list) from a product
	    name to the list of package names;
	    ({}, message) when no row has a mapped product name;
	    (None, message) when the database is unreachable, the table does not
	    exist or the selection fails.
	"""

	db = DB()
	if not db.Connect(mysql_config):
		return None, ERR_CONN

	table = branch + '_src'

	if not db.TableExists(table):
		return None, f'Table {table} does not exist'

	mapped_name = f"{vul_type}_name"
	cols = ["name", mapped_name]
	cond = f"{mapped_name} <> '' "

	cursor = db.Select(cols, table, cond)

	if not cursor:
		# A None cursor means the query itself failed; any other falsy value
		# is treated as an empty result
		if cursor is None:
			return None, f'Can\'t select {vul_type.upper()} product names ' \
				f'mapped to {branch} packages'
		return {}, f'Empty map of {vul_type.upper()} product names ' \
			f'to {branch} packages'

	res = defaultdict(list)
	for package, products in cursor:
		for product in products.split():
			res[product].append(package)

	return res, ''

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Prepare a dir inside the output dir (<out_dir>/<branch>/<vul_marker>, for ex.
# /home/user/.cve-manager/Sisyphus/CVE)

def PrepareDir(vul_type, branch):
	"""Create (if needed) the output dir for `vul_type` data of `branch`.

	The dir is <out_dir>/<branch>/<marker>, where <marker> is
	VUL_MARKERS[vul_type] without its last character.

	Returns a tuple (dir_abs_path, ok). ok is False when the dir could not be
	created (including when the path exists but is not a dir) or is not
	writable by the current process.
	"""

	dir_abs_path = os.path.join(
		os.path.abspath(args.out_dir),
		branch,
		VUL_MARKERS[vul_type][:-1]
		)

	# exist_ok avoids the check-then-create race; a non-dir at the path still
	# raises FileExistsError, which is an OSError
	try:
		os.makedirs(dir_abs_path, exist_ok=True)
	except OSError:
		return dir_abs_path, False

	return dir_abs_path, os.access(dir_abs_path, os.W_OK)

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Replace product names with package names in the software_list field of the
# vulnerability data using the given map of product names to package names, and
# then produce vulnerability descriptions (one JSON file per vulnerability)

def GenerateVulnerabilityDescriptions(vul_type, branch, vul):
	"""Write one JSON description file per vulnerability affecting `branch`.

	For every row of `vul` (rows as returned by GetVulnerabilities), the
	products listed in its software_list field are translated into package
	names of `branch` via PackagesMap. Rows that match no package are skipped.
	Each remaining row is written to <dir>/<vul_id>.json, where <dir> comes
	from PrepareDir and <vul_id> is "<marker><year>-<num>".

	Returns '' on success, or an error message on the first failure.
	"""

	vul_marker = VUL_MARKERS.get(vul_type)
	if not vul_marker:
		return f'Wrong type of vulnerabilities "{vul_type}"'

	dir_abs_path, ok = PrepareDir(vul_type, branch)
	if not ok:
		return 'Can\'t prepare a dir for writing results, the path is ' \
			f'"{dir_abs_path}"'

	packages_map, err = PackagesMap(vul_type, branch)
	if err:
		return f'Can\'t get a map of {vul_type.upper()} product names ' \
			f'to {branch} packages: {err}'

	for entry in vul:

		package_state = _PackageStates(entry, packages_map)
		if not package_state:
			continue

		vul_id = f'{vul_marker}{entry[VUL_YEAR]}-{entry[VUL_NUM]}'
		description = _VulnerabilityDescription(entry, vul_id, package_state)

		out_file_path = os.path.join(dir_abs_path, f'{vul_id}.json')

		# json.dump escapes quotes, backslashes and control characters found
		# in the data (e.g. in summaries); UTF-8 is used explicitly so that
		# non-ASCII text does not depend on the locale encoding
		try:
			with open(out_file_path, 'w', encoding='utf-8') as f:
				json.dump(description, f, ensure_ascii=False, indent=4)
		except OSError:
			return f'Can\'t write to {out_file_path}'

	return ''

# Build the "package_state" list of one vulnerability: one record for each pair
# of a vulnerable product and a package this product is mapped to
def _PackageStates(entry, packages_map):

	states = []

	# For each combination of vulnerable products of this vulnerability
	for vul_software_comb in SplitSoftware(entry[SOFTWARE_LIST]):
		# For each product of this combination
		for vul_software in vul_software_comb:
			for package in packages_map.get(vul_software.Product(), ()):
				states.append({
					'cpe': f'cpe:/a:{vul_software.FullName()}',
					'fix_state': '',
					'package_name': str(package),
					'product_name': '',
					})

	return states

# Build the JSON-serializable description of one vulnerability; field values
# are stringified with str() just as the former f-string template did
def _VulnerabilityDescription(entry, vul_id, package_state):

	# Prefer the CVSS v3 severity, fall back to v2 when it is empty
	severity = entry[CVSS_SEVERITY_V3] or entry[CVSS_SEVERITY_V2]

	return {
		'cvss3': {
			'cvss3_base_score': str(entry[CVSS_SCORE_V3]),
			'cvss3_scoring_vector': str(entry[CVSS_VECTOR_V3]),
			'status': '',
			},
		'cvss2': {
			'cvss2_base_score': str(entry[CVSS_SCORE_V2]),
			'cvss2_scoring_vector': str(entry[CVSS_VECTOR_V2]),
			'status': '',
			},
		'cwe': str(entry[CWE_ID]),
		'details': [
			str(entry[SUMMARY]),
			],
		'mitigation': '',
		'name': vul_id,
		'package_state': package_state,
		'public_date': str(entry[PUBLISHED]),
		'statement': '',
		'threat_severity': str(severity),
		'upstream_fix': '',
		}

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

if __name__ == '__main__':

	# Determining the issues types
	vul_types, _ = conf.SelectDataSources(args.vul_types)
	if not vul_types:
		printer.Err('Can\'t form a list of data sources')
		sys.exit(1)

	# Forming a list of branches that will be processed ('all' is passed on as
	# an empty selection)
	branches, err = \
		conf.SelectBranches(args.branches if 'all' not in args.branches else [])
	if err or not branches:
		printer.Err(err if err else 'Can\'t form a list of branch names')
		sys.exit(1)

	# Getting a dict of issues that should not be presented
	# NOTE(review): excluded_issues is not consulted anywhere below; it is kept
	# in case constructing ExcludedIssues has side effects (e.g. validating the
	# config) — confirm, then either filter `vul` with it or drop it
	excluded_issues = ExcludedIssues(GetDef('CONF_DIR', args.debug))

	# Any error aborts the whole run with a non-zero exit status
	for vul_type in vul_types:
		vul, err = GetVulnerabilities(vul_type)
		if err:
			printer.Err(err)
			sys.exit(1)
		for branch in branches:
			err = GenerateVulnerabilityDescriptions(vul_type, branch, vul)
			if err:
				printer.Err(err)
				sys.exit(1)

	sys.exit(0)
