#!/usr/bin/python3

#	cve-manager : CVE management tool
#	Copyright (C) 2017-2021 Alexey Appolonov
#
#	This program is free software: you can redistribute it and/or modify
#	it under the terms of the GNU General Public License as published by
#	the Free Software Foundation, either version 3 of the License, or
#	(at your option) any later version.
#
#	This program is distributed in the hope that it will be useful,
#	but WITHOUT ANY WARRANTY; without even the implied warranty of
#	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#	GNU General Public License for more details.
#
#	You should have received a copy of the GNU General Public License
#	along with this program.  If not, see <http://www.gnu.org/licenses/>.

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

import argparse
from collections                       import defaultdict
from ax.text                           import ReplaceNotAlfa
from cve_manager_knowledge_miner.query import Query

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Parsing the arguments

# Every kind of report that can be requested with --target
TARGETS = (
	'packages', 'related_packages',
	'cpe', 'fstec_products',
	'used_vendors', 'unused_vendors',
	'refs', 'product_marks',
	)

# Branches that are used when --branches is not given
BRANCHES = ('Sisyphus', 'p9', 'c9m2', 'c9f2', 'c81', 'c7')

# Command line interface: a mandatory report target, an optional list of
# branches and a debug switch; the parsed namespace is later handed to Query
argparser = argparse.ArgumentParser()
argparser.add_argument('-t', '--target', metavar='TARGET', type=str,
	choices=TARGETS, required=True, help=' | '.join(TARGETS))
argparser.add_argument('-b', '--branches', metavar='BRANCH_NAME', type=str,
	nargs='+', default=BRANCHES, help=f'[{" ".join(BRANCHES)}] by default')
argparser.add_argument('--debug', action='store_true',
	help='Run in debug mode')
args = argparser.parse_args()

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# For packages that have related names, print the pairs of packages that are
# mapped to the same product and the pairs that are mapped to different products

def TwoGroupsOfRelatedPackages(packages):

	"""Print pairs of related packages split into two groups.

	The first group holds pairs whose members are mapped to the same NVD
	product name; the second group holds pairs mapped to different ones.

	packages -- iterable of (name, nvd_name, related_names) rows, where
	            related_names is a ', '-separated string of package names.

	Each unordered pair is reported at most once. Relatives that are not
	themselves present in `packages`, as well as self-references, are
	skipped. Relatives are visited in sorted order so that the output does
	not depend on set iteration order (string hash seed).
	"""

	packages = {name: (nvd_name, set(related_names.split(', ')))
		for name, nvd_name, related_names in packages}

	K_SAME = 'same'
	K_DIFF = 'different'
	d = {K_SAME: [], K_DIFF: []}
	# Unordered pairs that have already been reported
	register = set()

	for name, (nvd_name, related_names) in packages.items():
		for related_name in sorted(related_names):
			# A package cannot be compared with itself, and a relative that
			# was not selected has no NVD name to compare with
			if related_name == name or related_name not in packages:
				continue
			pair = frozenset((name, related_name))
			if pair in register:
				continue
			register.add(pair)
			related_nvd_name = packages[related_name][0]
			k = K_SAME if related_nvd_name == nvd_name else K_DIFF
			d[k].append(
				f'{name} >> {nvd_name}, {related_name} >> {related_nvd_name}')

	for k, v in d.items():
		print(f'Packages that have related names and that are mapped to {k} '
			'product names:')
		print('-- None --' if not v else '\n'.join(v))
		print()

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Determine and print the products that are mapped to packages, with their vendors

def MappedProductsOfVendors(reverse_branch_map, products_of_vendors):

	"""Print one line per (vendor, package, product) match.

	reverse_branch_map  -- mapping: product name -> iterable of package names
	                       (products absent from it are not mapped)
	products_of_vendors -- mapping: vendor name -> iterable of product names

	Lines are grouped by vendor in the order of `products_of_vendors`;
	vendors with no mapped product produce no output.
	"""

	# Collect every match first, then print them all
	triples = [
		(vendor, package, product)
		for vendor, products in products_of_vendors.items()
		for product in products
		for package in reverse_branch_map.get(product, [])
		]

	for vendor, package, product in triples:
		print(
			f'[vendor: {vendor}],',
			f'[package: {package}],',
			f'[product: {product}]'
			)

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Determine and print vendors that have no mapped products

def VendorsWithNoMappedProducts(reverse_branch_map, products_of_vendors):

	"""Print vendors none of whose products are mapped, plus a total.

	reverse_branch_map  -- mapping whose keys are the mapped product names
	products_of_vendors -- mapping: vendor name -> collection of product names

	Vendors are listed by their number of products, largest first; ties keep
	the order of `products_of_vendors` (sorted() is stable).
	"""

	# A vendor counts as used as soon as one of its products is a key of the
	# map; a direct key lookup avoids scanning every vendor for every product
	unused_vendors = [
		(vendor, len(products))
		for vendor, products in products_of_vendors.items()
		if not any(product in reverse_branch_map for product in products)
		]

	for vendor, n_products in \
			sorted(unused_vendors, key=lambda pair: pair[1], reverse=True):
		# Only exactly one product takes the singular form ("0 products")
		additional_info = f'{n_products} product{"" if n_products == 1 else "s"}'
		print(vendor, f'({additional_info})')

	print(len(unused_vendors), 'unused vendors total')

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Detect and print the most frequently used prefixes and suffixes of product
# names

def __PrintRate(rate, min_count=5):

	"""Print the entries of `rate`, most frequent first.

	rate      -- mapping: mark (word) -> number of occurrences
	min_count -- entries counted fewer times than this are not printed
	             (defaults to the previously hard-coded threshold of 5)
	"""

	for mark, count in \
			sorted(rate.items(), key=lambda item: item[1], reverse=True):
		# Entries come in descending order, so every remaining one is below
		# the threshold as well
		if count < min_count:
			break
		print(f'\t{mark}: {count}')


def RateProductNamesPrefixesAndSuffixes(products):

	"""Count the first word (prefix) and the last word (suffix) of every
	product name and print the most frequent ones of each kind.

	Names are split into words on DELIM after ReplaceNotAlfa has been applied
	(presumably it turns every non-letter character into DELIM; confirm in
	ax.text). Empty words are dropped, and names of fewer than two words are
	ignored.
	"""

	DELIM = '_'
	prefix_rate = defaultdict(int)
	suffix_rate = defaultdict(int)

	for product in products:
		words = list(filter(None, ReplaceNotAlfa(product, DELIM).split(DELIM)))
		# A single-word name has no distinct prefix and suffix
		if len(words) >= 2:
			first, *_, last = words
			prefix_rate[first] += 1
			suffix_rate[last] += 1

	for title, rate in (('Prefix rate:', prefix_rate),
			('\nSuffix rate:', suffix_rate)):
		print(title)
		__PrintRate(rate)

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

if __name__ == '__main__':

	import sys

	query = Query(args)

	# NOTE(review): the Query methods whose return values are ignored below
	# presumably print their results themselves; confirm in Query
	if args.target == 'packages':
		query.Packages()

	elif args.target == 'related_packages':
		# Only packages that have both related names and an NVD product name
		# can be compared with each other. TwoGroupsOfRelatedPackages expects
		# (name, nvd_name, relatives) rows; presumably Packages returns the
		# name followed by the additional fields in this order. Confirm.
		additional_fields = ['nvd_name', 'relatives']
		condition = "relatives <> '' AND nvd_name <> ''"
		# NOTE(review): the meaning of the third positional argument (False)
		# is defined by Query.Packages; confirm
		packages = query.Packages(additional_fields, condition, False)
		TwoGroupsOfRelatedPackages(packages)

	elif args.target == 'cpe':
		query.CPEs()

	elif args.target == 'fstec_products':
		query.FSTECProducts()

	elif args.target in ('used_vendors', 'unused_vendors'):
		reverse_branch_map = query.ReverseIntegralBranchMap()
		products_of_vendors = query.ProductsOfVendors()
		if args.target == 'used_vendors':
			MappedProductsOfVendors(reverse_branch_map, products_of_vendors)
		else:
			VendorsWithNoMappedProducts(reverse_branch_map, products_of_vendors)

	elif args.target == 'refs':
		query.NVDVulImportRefs()

	elif args.target == 'product_marks':
		products = query.ProductNames()
		RateProductNamesPrefixesAndSuffixes(products)

	# sys.exit works even when the interpreter runs without the site module,
	# which is what provides the bare exit() builtin
	sys.exit(0)
