#!/usr/bin/python3

#	cve-manager : CVE management tool
#	Copyright (C) 2017-2022 Alexey Appolonov
#
#	This program is free software: you can redistribute it and/or modify
#	it under the terms of the GNU General Public License as published by
#	the Free Software Foundation, either version 3 of the License, or
#	(at your option) any later version.
#
#	This program is distributed in the hope that it will be useful,
#	but WITHOUT ANY WARRANTY; without even the implied warranty of
#	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#	GNU General Public License for more details.
#
#	You should have received a copy of the GNU General Public License
#	along with this program.  If not, see <http://www.gnu.org/licenses/>.

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

import argparse
import subprocess
import cve_manager.desc        as desc
from time                      import time, localtime, strftime, sleep
from sys                       import argv
from ax.printer                import Printer
from cve_manager.common        import NewArgParser
from cve_manager.intf_download import CmdDownload
from cve_manager.intf_map      import CmdMap

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Constants

# Names of the modules (the executable of a module is its name with the
# 'cve-' prefix, or with the 'cpe-' prefix for the mapper, see PATHS)
MANAGER  = 'manager'
BACKUP   = 'backup'
HISTORY  = 'history'
DOWNLOAD = 'download'
IMPORTER = 'import'
MAPPER   = 'map'
ISSUES   = 'issues'
MONITOR  = 'monitor'

# Modules that can be run as steps, in the order of their execution (the
# manager itself and the monitor are not among them)
MODULES = [
	BACKUP,
	HISTORY,
	DOWNLOAD,
	IMPORTER,
	MAPPER,
	ISSUES,
	]

# Names of the executables of the step modules
PATHS = {
	name: ('cpe' if name == MAPPER else 'cve') + '-' + name
	for name in MODULES
	}

# Executables relative to the current directory (the importer is looked up
# in 'cve-import/bin/'); RunModule uses these when "--debug" is given
DEBUG_PATHS = {
	name: ('./cve-import/bin/' if name == IMPORTER else './') + PATHS[name]
	for name in PATHS
	}

# Human-readable description of every module, including the manager itself
# and the monitor (which are not steps)
DESCRIPTIONS = dict((
	(MANAGER,  desc.MANAGER),
	(BACKUP,   desc.BACKUP),
	(HISTORY,  desc.HISTORY),
	(DOWNLOAD, desc.DOWNLOAD),
	(IMPORTER, desc.IMPORTER),
	(MAPPER,   desc.MAPPER),
	(ISSUES,   desc.ISSUES),
	(MONITOR,  desc.MONITOR),
	))

# Pause (in seconds) before each repeated attempt to perform a failed step
PAUSE = 60

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Parsing the arguments

argparser = argparse.ArgumentParser(description=desc.MANAGER)

# Options of the manager itself as (flags, keyword arguments of
# add_argument), registered in this order
for _flags, _options in (
	(('-a', '--run_all'), dict(
		action='store_true',
		help='Run all the modules',
		)),
	(('-b', '--beginning_step'), dict(
		metavar='MODULE_NAME', type=str, choices=MODULES,
		help='Beginning step (so you could skip some of the early stages)',
		)),
	(('-e', '--ending_step'), dict(
		metavar='MODULE_NAME', type=str, choices=MODULES,
		help='Ending step (so you could skip some of the final stages)',
		)),
	(('-o', '--offline'), dict(
		action='store_true',
		help='Do not download vulnerability lists, acls, etc.',
		)),
	(('-r', '--retry'), dict(
		metavar='N_REPEATED_ATTEMPTS', type=int, default=0,
		help='Number of repeated attempts in case of a failure of executed step',
		)),
	(('-l', '--list_modules'), dict(
		action='store_true',
		help='List available modules (with descriptions)',
		)),
	):
	argparser.add_argument(*_flags, **_options)
# Don't leave the loop variables behind as module globals
del _flags, _options

# NOTE(review): args.debug, args.plain and args.silent are read by this
# script but not declared above; presumably NewArgParser adds them to the
# base parser -- confirm in cve_manager.common
argparser = NewArgParser(ptype='m', base=argparser)
args = argparser.parse_args()

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Print descriptions of available modules

def ListModules():
	"""Print the name of every step module followed by its description,
	each line of which is indented with a tab."""

	for name in MODULES:
		lines = [name + ':']
		lines.extend('\t' + text for text in DESCRIPTIONS[name].splitlines())
		print('\n'.join(lines))

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Run a given module

def RunModule(module, params, printer):
	"""Run one step module as a subprocess and report how long it took.

	"params" is a whitespace-separated string of arguments for the module.
	The manager's own "--debug", "--plain" and "--silent" flags are passed
	on to it when they are set.

	Returns the exit status of the module, or -1 if "module" is not one of
	MODULES. Terminates the program with status 1 if the executable of the
	module can't be found.
	"""

	if module not in MODULES:
		printer.Err(f'There is no "{module}" module')
		return -1

	path = (DEBUG_PATHS if args.debug else PATHS)[module]
	arguments = params.strip().split()
	forwarded = [
		flag
		for flag, enabled in (
			('--debug', args.debug),
			('--plain', args.plain),
			('--silent', args.silent),
			)
		if enabled
		]
	command = [path, *arguments, *forwarded]

	printer.Div()
	started = strftime("%Y-%m-%d, %H:%M:%S", localtime())
	cmdline = ' '.join(command)
	printer.LineEnd(f'({started}) Running "{cmdline}"')

	t0 = time()
	try:
		completed_process = subprocess.run(command)
	except FileNotFoundError:
		hint = 'unintentional' if args.debug else 'missing'
		printer.Err(f'Can\'t find {path}, perhaps the "--debug" flag is {hint}')
		exit(1)
	elapsed = time() - t0
	# Decimal comma in the reported duration
	printer.LineEnd(f'({elapsed:.2f} sec)'.replace('.', ','))

	return completed_process.returncode

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

def SelectSteps(modules, beginning_step=None, ending_step=None, skipped=()):
	"""Return the steps to execute, in the order they have in "modules".

	The result runs from "beginning_step" through "ending_step" inclusive
	(by default from the first module to the last one) and leaves out the
	modules listed in "skipped". Both bounds are located in the full
	"modules" list, so a bound may name a skipped module. For example, with
	"download" skipped, beginning at "download" starts with the module that
	follows it.

	Raises ValueError if a bound is not in "modules" or if the beginning step
	comes after the ending step.
	"""

	begin = modules.index(beginning_step) if beginning_step else 0
	end = modules.index(ending_step) if ending_step else len(modules) - 1
	if beginning_step and ending_step and begin > end:
		raise ValueError(f'Beginning step "{beginning_step}" comes after '
			f'ending step "{ending_step}"')
	return [m for m in modules[begin:(end + 1)] if m not in skipped]

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

if __name__ == '__main__':

	if len(argv) < 2:
		argparser.print_help()
		exit(1)

	if args.list_modules:
		ListModules()
		exit(0)

	if not args.run_all and not args.beginning_step and not args.ending_step:
		argparser.error('Missing required flag "-a" or "-b" or "-e"')
		exit(1)  # fallback in case error() returns instead of exiting

	t0 = time()

	# Forming the steps: "-a" runs all of them (overriding "-b" and "-e"),
	# "-o" skips the download step. The bounds are looked up among all the
	# modules, so "-o -b download" starts with the step after the download
	try:
		steps = SelectSteps(
			MODULES,
			beginning_step=None if args.run_all else args.beginning_step,
			ending_step=None if args.run_all else args.ending_step,
			skipped=(DOWNLOAD,) if args.offline else (),
			)
	except ValueError as e:
		argparser.error(str(e))
		exit(1)  # fallback in case error() returns instead of exiting
	params = {
		BACKUP   : '--backup --store 2',
		HISTORY  : '--store 15',
		DOWNLOAD : CmdDownload(),
		IMPORTER : '--everything',
		MAPPER   : CmdMap(),
		ISSUES   : '--branches all --prepare',
		}
	# Steps whose parameters are rebuilt from the return code of a failed
	# attempt (presumably to adjust the command for the next attempt)
	param_builders = {DOWNLOAD: CmdDownload, MAPPER: CmdMap}

	# Executing the steps
	printer = Printer(silent=args.silent, plain=args.plain)
	for step in steps:
		count = 0
		while True:
			return_code = RunModule(step, params[step], printer)
			if return_code == 0:
				break
			if step in param_builders:
				params[step] = param_builders[step](return_code)
			# ">=" rather than "==": a negative "--retry" must not loop forever
			if count >= args.retry:
				exit(1)
			count += 1
			printer.LineEnd(f'Retrying in {PAUSE} seconds...')
			sleep(PAUSE)

	dt = time() - t0
	printer.LineEnd(f'\n(Total time {dt:.2f} sec)'.replace('.', ','))

	exit(0)
