#!/usr/bin/python3

#	cve-manager : CVE management tool
#	Copyright (C) 2017-2025 Alexey Appolonov
#
#	This program is free software: you can redistribute it and/or modify
#	it under the terms of the GNU General Public License as published by
#	the Free Software Foundation, either version 3 of the License, or
#	(at your option) any later version.
#
#	This program is distributed in the hope that it will be useful,
#	but WITHOUT ANY WARRANTY; without even the implied warranty of
#	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#	GNU General Public License for more details.
#
#	You should have received a copy of the GNU General Public License
#	along with this program.  If not, see <http://www.gnu.org/licenses/>.

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

import os
import re
import gzip
import zipfile
import argparse
import subprocess
from sys                     import argv
from shutil                  import copyfileobj, rmtree
from requests                import get as requesturl
from ax.filesystem           import PrepareDir
from cve_manager             import const
from cve_manager.desc        import DOWNLOAD
from cve_manager.common      import NewArgParser, Init, ReviseCVEVolumes
from cve_manager.conf        import COMMON_SEC
from cve_manager.intf_common import ErrEncode, ERR_MAX

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Parsing the arguments

class YearAction(argparse.Action):
	'''
	Store the values of the "--vols" option only when NVD vulnerabilities are
	going to be downloaded; otherwise print a warning and ignore them.

	argparse handles options in command-line order, so when "--vols" precedes
	"--info" the requested targets are not known yet (the "info" attribute
	still holds its default None). In that case the values are stored; they
	are only used for the NVD volume selection anyway.
	'''
	def __call__(self, parser, namespace, values, option_string=None):
		# No option defined in this file has the "nvd_vul" dest, so the old
		# attribute check alone never succeeded; it is kept for compatibility,
		# and the targets requested with "--info" are checked as well ('all'
		# first, so the common case needs no further lookup)
		info = getattr(namespace, 'info', None)
		got_nvd = getattr(namespace, 'nvd_vul', False) or info is None or \
			'all' in info or const.DT_NVD_VUL in info
		if got_nvd:
			setattr(namespace, self.dest, values)
		else:
			print('"year" param used only with "nvd_vul" param!')

# Command-line interface of the download module
argparser = argparse.ArgumentParser(description=DOWNLOAD)
# Download targets: any of const.ALL_DOWNLOADS, or "all" (expanded right after
# parsing)
argparser.add_argument(
	'-i', '--info',
	metavar='INFO', type=str, required=True,
	nargs='+', choices=(const.ALL_DOWNLOADS + ['all']),
	help=f'Information to download ({", ".join(const.ALL_DOWNLOADS)} or all)'
	)
# Skip downloads whose data already exists; honoured by the Python-based
# downloads (GetContents/Request) only
argparser.add_argument(
	'-n', '--noreplace',
	action='store_true',
	help='Do not replace existing files'
	)
# NVD volumes to download; YearAction decides whether the values are stored.
# Note that the default is the plain string 'all', whereas explicitly given
# values arrive as a list (nargs='+')
argparser.add_argument(
	'-v', '--vols',
	default='all',
	action=YearAction,
	metavar='<year>|recent', type=str, nargs='+',
	help='Volumes of the NVD CVE lists (all by default)'
	)
# Marks a restart of the module; Request() uses it to decide whether the TLS
# certificate of the FSTEC server is verified
argparser.add_argument(
	'-r', '--retry',
	action='store_true',
	help='This is not a first run (the module has been terminated and now it '
	'is running again)'
	)
# NOTE(review): NewArgParser comes from cve_manager.common; presumably it adds
# the options shared by the modules (ptype='m') to this base parser — confirm
argparser = NewArgParser(base=argparser, ptype='m')
args = argparser.parse_args()

# The pseudo-target "all" stands for every known download target
if 'all' in args.info:
	args.info = const.ALL_DOWNLOADS

# Completion flag of every requested target, initially not completed;
# Err() encodes it into the exit code
checklist = dict.fromkeys(args.info, False)

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Terminate with a return code formed from a list of those download targets
# that have not been completed

def Err(returncode=None):
	'''
	Terminate the module.

	returncode -- exit code to use; when None, the code is built by
	ErrEncode() from the checklist of download targets, so that it reflects
	which of them have not been completed
	'''
	# Terminate in both cases: an explicit code must not be merely returned,
	# since callers such as Err(ERR_MAX) ignore the return value
	if returncode is None:
		returncode = ErrEncode(checklist, const.ALL_DOWNLOADS)
	raise SystemExit(returncode)

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Initialising a printer and reading the configuration file

# Create the printer and read the configuration (Init() comes from
# cve_manager.common)
printer, conf = Init(args)
# Parameters of the common section of the configuration file
common_params, = conf.Get([COMMON_SEC])
# Branches to work with; on error, report it and terminate via Err()
available_branches, err = conf.SelectBranches()
if err:
	printer.Err(err)
	Err()
# Distro lists of the selected branches (turned into targets by DistroTargets)
distro_lists = conf.GetDistroLists(available_branches)

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Select NVD volumes that will be downloaded

def SelectNVDVolumes():
	'''
	Return the NVD volumes to download, as revised by ReviseCVEVolumes() from
	the "--vols" values; an empty revised list is handed over to
	Err(ERR_MAX). When "--vols" holds nothing, "recent" is appended.
	'''
	volumes = ReviseCVEVolumes(args.vols, printer)
	if not volumes:
		Err(ERR_MAX)
	if args.vols:
		return volumes
	# No volumes were given, so the "recent" feed is added
	volumes.append('recent')
	return volumes

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Generate REST API queries to ALT Linux RDB

def AltErrata(branch):
	'''
	Return the ALT Linux RDB REST API URL exporting the OVAL errata of the
	given branch (with the one_file=true query).
	'''
	api_root = 'https://rdb.altlinux.org' + '/api/errata/export/oval/'
	return api_root + branch + '?one_file=true'

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Generate a map of download targets using a map of distro lists specified for
# various branches

def DistroTargets(distro_lists):
	'''
	Map "<list type>/<branch>/<distro name>" to a one-element list holding
	the path of the distro list.

	distro_lists -- {list type: {branch: (distro name, distro list path)}}
	'''
	return {
		f'{list_type}/{branch}/{name}': [path]
		for list_type, per_branch in distro_lists.items()
		for branch, (name, path) in per_branch.items()
		}

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Forming possible download targets

# NVD JSON 2.0 feeds: the CPE dictionary and the per-volume CVE lists
NVD_FEEDS_URL = 'https://nvd.nist.gov/feeds/'
CPE_URL       = f'{NVD_FEEDS_URL}json/cpe/2.0/nvdcpe-2.0.zip'
NVD_VUL_URL   = f'{NVD_FEEDS_URL}json/cve/2.0/'
# FSTEC vulnerability database (a zip archive)
FSTEC_VUL_URL = 'https://bdu.fstec.ru/files/documents/vulxml.zip'
# Two alternative git sources of the Linux kernel vulns repository, tried in
# this order
LINUX_KERNEL_VULNS0  = 'git://git.kernel.org/pub/scm/linux/security/vulns.git'
LINUX_KERNEL_VULNS1  = 'https://kernel.googlesource.com/pub/scm/linux/security/vulns.git'

# rsync source of the ACLs (see FetchData)
ACLS_URL      = 'git.altlinux.org::acl/'
# Archive the timelines are fetched from with wget (see FetchData)
TIMELINES_URL = 'ftp://ftp.altlinux.org/pub/distributions/archive/'

# Targets fetched with Python itself (GetContents): each value is a list of
# sub-targets that must all succeed, each sub-target being a list of
# alternative URLs tried until one of them succeeds.
# NOTE(review): SelectNVDVolumes() is evaluated here, at import time, even
# when NVD vulnerabilities are not requested
SIMPLE_REQUEST_TARGETS = {
	const.DT_CPE: [[CPE_URL]],
	const.DT_NVD_VUL: [[f'{NVD_VUL_URL}nvdcve-2.0-{vol}.json.gz'] for vol in SelectNVDVolumes()],
	const.DT_FSTEC_VUL: [[FSTEC_VUL_URL]],
	}

# Targets fetched with external utilities (FetchData): each value is a list
# of alternative sources tried in order until one succeeds (empty ones, e.g.
# unset conf parameters, are skipped); a source that is itself a list is
# fetched element by element and all elements must succeed
COMPLEX_REQUEST_TARGETS = {
	const.DT_ACL: [ACLS_URL],
	const.DT_TIMELINES: [f'{TIMELINES_URL}{common_params["master_branch"].lower()}/index/src/'],
	const.DT_DISTRO_SRC: [common_params.get('distro_lists_src')],
	const.DT_DISTRO_BIN: [common_params.get('distro_lists_bin')],
	# A single source: the list of errata URLs of all selected branches
	# except Sisyphus, every one of which has to be fetched
	const.DT_ALT_ERRATA: [[AltErrata(branch) for branch in available_branches
		if branch.lower() != 'sisyphus']],
	const.DT_LINUX_KERNEL_VULNS: [LINUX_KERNEL_VULNS0, LINUX_KERNEL_VULNS1],
	# Per-branch distro-list targets are merged in from DistroTargets()
	} | \
	DistroTargets(distro_lists)

# url -> local file name overrides used by Request() (none at the moment)
CUSTOM_FILE_NAMES = {
	}

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Decorator function that prints out given message, tries to execute decorated
# function, checks the resulting string (including with the use of optional
# check function), prints a status message and returns the resulting string

def TryAndHandle(func):
	'''
	Turn func into a reporting action.

	The returned wrapper is called as wrapper(msg, params, optional_check,
	verbose). It prints msg and calls func(*params). An exception or an
	empty result is a failure. A non-empty result is also a failure when
	optional_check is given and optional_check(result) is false.
	On failure the error (the exception text, if any) is printed and ''
	is returned. On success the result (unless verbose is false) and the
	success status are printed, and the result is returned.
	'''
	from functools import wraps

	@wraps(func)
	def Wrapper(msg, params, optional_check=None, verbose=True):
		printer.LineBegin(msg)
		try:
			res = func(*params)
			err_msg = ''
		except Exception as e:
			# Report rather than propagate: callers only test the result
			res = ''
			err_msg = str(e)
		if not res or (optional_check and not optional_check(res)):
			printer.Err(err_msg)
			# Return '' also when a non-empty result failed the check, so that
			# the caller does not take it for a success
			return ''
		if verbose:
			printer.LineAddExtra(f'{res}')
		printer.Success()
		return res

	return Wrapper

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Downloading, unzipping and then removing a zip (gz) file

# Headers sent with every HTTP request made by Request()
REQUEST_HEADERS = {'User-Agent': 'cve-manager'}

# Download a file (usually a gz/zip archive) from url into the base dir
def Request(url, base, noreplace):
	'''
	Download url into the base dir.

	The file is named after the last component of url, unless
	CUSTOM_FILE_NAMES maps url to another name. Returns the path of the
	downloaded file, '' on failure, or 'Noreplace' (without downloading)
	when noreplace is set and the data already exists.
	'''
	custom_file_name = CUSTOM_FILE_NAMES.get(url)
	file_name = custom_file_name if custom_file_name else url.split('/')[-1]
	file_path = os.path.join(base, file_name)

	# Archives are deleted once unpacked (see ProcessArchiveFile), so the data
	# is looked up under the archive name without its suffix, which is where
	# Ungzip/Unzip put it: "x.json.gz" -> "x.json", "x.zip" -> "x".
	# (rstrip('.gz') stripped characters, not the suffix, and left ".zip"
	# names unchanged, so noreplace never applied to zip downloads)
	root, ext = os.path.splitext(file_path)
	unpacked_path = root if ext in ('.gz', '.zip') else file_path
	if noreplace and os.path.exists(unpacked_path):
		return 'Noreplace'

	@TryAndHandle
	def Retrieve(url, file_path):
		# TLS certificate verification is skipped only when retrying the
		# download of the FSTEC database (being the first requested target)
		# and the configuration does not demand verification
		verify = common_params['always_verify_tls_certificate'] or \
			not args.retry or \
			args.info[0] != const.DT_FSTEC_VUL or \
			url != FSTEC_VUL_URL
		r = requesturl(url, headers=REQUEST_HEADERS, timeout=10, verify=verify)
		if not r.ok:
			return ''
		with open(file_path, 'wb') as f:
			f.write(r.content)
		return file_path

	return Retrieve(f'Downloading "{url}"', [url, file_path])

# Unpack an archive file and delete the archive afterwards
def ProcessArchiveFile(file_path, out_dir_path=None):
	'''
	Unpack the .gz or .zip file file_path, then remove the archive.

	out_dir_path -- handed over to the unarchiver (Ungzip/Unzip); None keeps
	their default output location
	Returns True when file_path is not an archive (nothing to do) or when
	both unpacking and removal succeeded, False otherwise.
	'''
	if not file_path.endswith(('.gz', '.zip')):
		# Not an archive file, nothing to unpack
		return True
	unarchive = Ungzip if file_path.endswith('.gz') else Unzip

	# os.path.exists is handed over as the check of the unpacked result
	unpacked = unarchive(f'Unzipping "{file_path}"',
		[file_path, out_dir_path], os.path.exists)
	if not unpacked:
		return False

	# The removal is checked by the absence of the archive
	def is_gone(path):
		return not os.path.exists(path)

	removed = Remove(f'Removing "{file_path}"', [file_path], is_gone,
		verbose=False)
	return bool(removed)

# Unarchive a gz file
@TryAndHandle
def Ungzip(file_path, out_dir_path=None):
	'''
	Decompress the gz file file_path and return the path of the output file.

	out_dir_path -- despite the name, the path of the output FILE; when it is
	not given, the output is written next to the archive, under the archive
	name without the ".gz" suffix
	'''
	# removesuffix rather than rstrip('.gz'), which strips any trailing '.',
	# 'g' and 'z' characters ("blog.gz" would become "blo")
	out_path = out_dir_path if out_dir_path else file_path.removesuffix('.gz')
	with gzip.open(file_path, 'rb') as f_in, open(out_path, 'wb') as f_out:
		copyfileobj(f_in, f_out)
	return out_path

# Unarchive a zip file
@TryAndHandle
def Unzip(file_path, out_dir_path=None):
	'''
	Extract all members of the zip file file_path and return the output dir.

	out_dir_path -- dir to extract into; when it is not given, a dir named
	after the archive without the ".zip" suffix is used
	'''
	# removesuffix rather than rstrip('.zip'), which strips any trailing '.',
	# 'z', 'i' and 'p' characters ("strip.zip" would become "str")
	out_path = out_dir_path if out_dir_path else file_path.removesuffix('.zip')
	with zipfile.ZipFile(file_path, 'r') as f_zip:
		f_zip.extractall(out_path)
	return out_path

# Delete a downloaded archive
@TryAndHandle
def Remove(file_path):
	'''
	Delete file_path when it is a regular file; anything else (a dir, a
	missing path) is left untouched. Always returns file_path.
	'''
	if not os.path.isfile(file_path):
		return file_path
	os.remove(file_path)
	return file_path

# Download a file into base, unpack it if it is an archive and drop the archive
def GetContents(url, base, noreplace=args.noreplace):
	'''
	Download url into the base dir and process it with ProcessArchiveFile().

	noreplace -- defaults to the "--noreplace" command-line flag
	Returns True on success or when the download was skipped because the
	data already exists, False otherwise.
	'''
	result = Request(url, base, noreplace)
	if result == 'Noreplace':
		printer.LineEnd(f'[NOTE: Won\'t download "{url}", the data exists '
			f'in "{base}"]')
		return True
	if not result:
		return False
	return ProcessArchiveFile(result)

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Fetch data using wget, git, rsync or another utility

def FetchData(target, src, dest):
	'''
	Fetch a complex download target with an external utility.

	target -- download target name; it selects the utility: rsync ("acl"),
	          wget ("alt_errata", "timelines"), rsync, install or git (distro
	          lists, Linux kernel vulns)
	src    -- source: URL, rsync source, local dir/file or git repository
	dest   -- destination: a dir for most targets, a file for the per-branch
	          distro-list targets (install -D)

	The progress and the status are printed. When the utility fails, the
	output at dest is removed, and the failure is reported as an error for
	required downloads and as a warning otherwise. Returns True on success,
	False on failure.
	'''
	import shlex

	printer.LineBegin(f'Downloading "{src}"')

	cmd = []
	shell = False
	timeout = None

	if target == 'acl':
		cmd += ['rsync', '--timeout=30', '-aq', src]
	elif target == 'alt_errata':
		# ".../oval/<branch>?one_file=true" -> "<branch>"
		branch_name = re.sub(r'\?.*', '', re.sub(r'.*\/', '', src))
		dest_dir = dest
		dest = os.path.join(dest_dir, f'{branch_name}.zip')
		# A shell is needed for "&&"; the operands are quoted since the URL
		# contains '?', which is a shell wildcard
		cmd += ['mkdir', '-p', shlex.quote(dest_dir), '&&',
			'wget', shlex.quote(src), '-O']
		shell = True
	elif target == 'timelines':
		cmd += ['wget', '--timeout=10', '-rq', src, '-A', 'd-t-s-evr.list',
			'-nH', '--cut-dirs', '3', '-P']
	elif target in (const.DT_DISTRO_SRC, const.DT_DISTRO_BIN, const.DT_LINUX_KERNEL_VULNS):
		if src.startswith('/'):
			# A local dir: the shell expands the "*" wildcards.
			# NOTE(review): src comes from the configuration and reaches the
			# shell unquoted (keeping any expansion users may rely on), so it
			# must not contain spaces or other shell metacharacters
			if target == const.DT_DISTRO_BIN:
				cmd += ['install', '-m', '644', os.path.join(src, '*.list'), '-Dt']
			else:
				cmd += ['rsync', '-a', '--exclude=".*"', os.path.join(src, '*')]
			shell = True
		elif src.startswith('git://') or src.endswith('.git'):
			timeout = 180
			# git clones only into a missing or empty dir, so the old copy is
			# removed first
			if os.path.exists(dest):
				try:
					rmtree(dest)
				except Exception as exc:
					printer.Err('Unable to recreate the destination dir '
						f'"{dest}": {exc}')
					return False
			cmd += ['git', 'clone', src]
		else:
			printer.Err(f'Invalid value "{src}" of a conf parameter')
			return False
	elif 'distro' in target:
		# NOTE(review): src reaches the shell unquoted, see above
		cmd += ['install', '-m', '644', src, '-D']
		shell = True
	else:
		printer.Err(f'Wrong target "{target}"')
		return False

	# The destination is the last argument of every command
	cmd.append(shlex.quote(dest) if shell else dest)

	try:
		completed_process = subprocess.run(
			' '.join(cmd) if shell else cmd,
			stdout=subprocess.DEVNULL,
			stderr=subprocess.STDOUT,
			shell=shell,
			timeout=timeout,
			)
	except Exception as exc:
		printer.Err(str(exc))
		return False

	if completed_process.returncode != 0:
		msg = f'The "{" ".join(cmd)}" command returned code ' \
			f'{completed_process.returncode}'
		# Do not leave partial output behind
		try:
			if os.path.isdir(dest):
				rmtree(dest)
			elif os.path.exists(dest):
				os.remove(dest)
		except Exception as exc:
			msg += f'; Unable to delete the output file "{dest}": {exc}'
		if target in const.REQUIRED_DOWNLOADS:
			printer.Err(msg)
		else:
			printer.Warn(msg)
		return False

	if target == 'alt_errata':
		if not ProcessArchiveFile(dest, dest_dir):
			printer.Err(f'Unable to unarchive the file "{dest}"')
			return False

	printer.Success()
	printer.LineEnd(f'\t- {dest}')

	return True

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Filter the map of all possible download targets using the map representing
# the user's choice

def FilterDownloadTargets(source_map):
	'''
	Return the entries of source_map whose keys start with one of the
	targets requested with "--info" (so "x" also selects keys such as
	"x/<branch>/<name>").
	'''
	requested = tuple(args.info)
	return {name: value for name, value in source_map.items()
		if name.startswith(requested)}

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Convert a string of sources to a list of sources (git-sources will be placed
# in the beginning of the list)

def ConvStrToList(src):
	'''
	Normalise a download source to a list of sources.

	A non-list src is wrapped into a one-element list. In a list, the
	git-source (one starting with "git://") is moved to the beginning, and
	the order of the other sources is kept.
	Returns (sources, err): err is '' on success, or a message when the list
	holds more than one git-source (the list is then returned unchanged).
	'''
	if type(src) is not list:
		return [src], ''

	# Determining the index of a git-source
	index = -1
	for i, el in enumerate(src):
		if el.startswith('git://'):
			# ">= 0", not "> 0": a git-source found at position 0 counts too
			if index >= 0:
				return src, 'More than one git-source'
			index = i

	if index < 0:
		return src, ''

	return src[index:index + 1] + src[:index] + src[index + 1:], ''

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

if __name__ == '__main__':

	# NOTE(review): "--info" is required, so parse_args() above presumably
	# fails before this point when no arguments are given; this check relies
	# on Err() terminating when an explicit code is passed
	if len(argv) < 2:
		argparser.print_help()
		Err(ERR_MAX)

	# Checking and converting download path
	base, err = PrepareDir(common_params['download'])
	if not base:
		printer.Err(err)
		Err(ERR_MAX)

	# Forming a dict of download targets (only those requested with "--info")
	simple_targets = FilterDownloadTargets(SIMPLE_REQUEST_TARGETS)
	complex_targets = FilterDownloadTargets(COMPLEX_REQUEST_TARGETS)

	# Filtering excluded data sources from the dict: a "<src>_vul" target is
	# dropped when <src> is listed in the "excluded_vulsrc" conf parameter
	excluded_url_targets = [t for t in simple_targets if
		t.endswith('_vul') and
		t[:-4] in common_params.get('excluded_vulsrc', [])]
	if excluded_url_targets:
		simple_targets = {k: v for k, v in simple_targets.items()
			if k not in excluded_url_targets}
		msg = f'{", ".join(excluded_url_targets)} ' \
			f'{"are" if len(excluded_url_targets) > 1 else "is"} excluded'
		printer.LineEnd(f'[NOTE: {msg}]')

	# Download the simple targets using the internal Python functionality.
	# Each sub-target has to be fetched from one of its alternative URLs; a
	# failed sub-target of a required target terminates the module via Err()
	for target, sub_targets in simple_targets.items():
		target_check = True
		for alternatives in sub_targets:
			sub_target_check = False
			for url in alternatives:
				if GetContents(url, base):
					sub_target_check = True
					break
			if not sub_target_check and target in const.REQUIRED_DOWNLOADS:
				Err()
			target_check &= sub_target_check
		checklist[target] = target_check

	# Download the complex targets using various utilities like git and wget.
	# Alternative sources are tried in order until one succeeds; a source
	# that is a list succeeds only if all of its elements are fetched
	for target, alternatives in complex_targets.items():
		target_check = False
		for src in alternatives:
			# Skip the remaining sources once one succeeded, and unset ones
			if target_check or not src:
				continue
			src, err = ConvStrToList(src)
			if err:
				printer.Err(err)
				Err()
			# Dir destinations get a trailing '/'; the targets built by
			# DistroTargets() (presumably containing 'distro') are files
			dest = os.path.join(base, target)
			if dest[-1] != '/' and ('distro' not in target or \
					target in (const.DT_DISTRO_SRC, const.DT_DISTRO_BIN)):
				dest += '/'
			target_check = True
			for _src in src:
				if not FetchData(target, _src, dest):
					target_check = False
					if target in const.REQUIRED_DOWNLOADS:
						Err()
			if target == 'alt_errata':
				# Rename ALT errata <branch>_<date>.xml files to <branch>.xml
				# NOTE(review): every entry whose name starts with [a-z0-9] is
				# renamed, so a <branch>.zip left behind by a failed unpacking
				# would become <branch>.xml too
				for file_name in os.listdir(dest):
					m = re.search(r'^([a-z0-9]+)', file_name)
					if m:
						file_path0 = os.path.join(dest, file_name)
						file_path  = os.path.join(dest, f'{m.group(1)}.xml')
						try:
							os.rename(file_path0, file_path)
						except Exception as exc:
							printer.Err(str(exc))
							Err()
		checklist[target] = target_check

	# Failures of non-required targets are only reported (and recorded in the
	# checklist); the exit code is 0 regardless
	exit(0)
