#!/usr/bin/python3

#	cve-manager : CVE management tool
#	Copyright (C) 2017-2022 Alexey Appolonov
#
#	This program is free software: you can redistribute it and/or modify
#	it under the terms of the GNU General Public License as published by
#	the Free Software Foundation, either version 3 of the License, or
#	(at your option) any later version.
#
#	This program is distributed in the hope that it will be useful,
#	but WITHOUT ANY WARRANTY; without even the implied warranty of
#	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#	GNU General Public License for more details.
#
#	You should have received a copy of the GNU General Public License
#	along with this program.  If not, see <http://www.gnu.org/licenses/>.

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

from Levenshtein          import distance as Distance
from os                   import cpu_count
from re                   import search
from types                import SimpleNamespace
from ax.sefunctions       import PrepareDir
from cve_manager.conf     import COMMON_SEC
from cve_manager.parallel import Parallel
from cpe_map.defines      import M_PARTIAL, M_BIN_PARTIAL
from cpe_map.init         import Init
from cpe_map.common       import BlankArgParser, NewArgParser, \
	CompatibleGroups, MatchesStr
from cpe_map              import prepare_partial
from cpe_map.control      import TerminateIfAbsent, TerminateOrSkip
from cpe_map.name_conv    import NAME_GROUP, NAME_FORMATTED_SPLITTED, \
	NAME_FORMATTED_SPLITTED_WO_MARK

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Forming a set of arguments

# Two extra store_true flags are registered on the blank parser, which is then
# passed as the base to NewArgParser
argparser = BlankArgParser()
argparser.add_argument('--bin', action='store_true',
	help='Perform matching of binary package names')
argparser.add_argument('--prepare', action='store_true',
	help="Prepare supplementary data even if it's been prepared before")
argparser = NewArgParser(base=argparser, ptype='p')

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Check that names x and y differ only in the numerical part at the right

def SameAlphaPart(x, y):
	'''
	Check that names x and y differ only in the (numerical) part at the right.

	The common leading part of x and y is cut right after its last Latin
	letter. The result is True only if:
	- that alpha part is non-empty,
	- the remainder of each name contains no Latin letter and no '+',
	- the alpha part is at least as long as the remainder of each name.

	Examples: ('python3', 'python2') -> True, ('libfoo1', 'libbar1') -> False.
	'''

	def IsLatinLetter(c):
		# Two explicit ranges: a single 'A'..'z' range would also accept the
		# punctuation between 'Z' and 'a' ([ \ ] ^ _ `)
		return 'A' <= c <= 'Z' or 'a' <= c <= 'z'

	# Length of the common prefix up to and including its last letter
	intersec = 0
	for i, (cx, cy) in enumerate(zip(x, y)):
		if cx != cy:
			break
		if IsLatinLetter(cx):
			intersec = i + 1

	if intersec == 0:
		return False

	# Whatever follows the alpha part must be free of letters and '+'
	for name in (x, y):
		if any(IsLatinLetter(c) or c == '+' for c in name[intersec:]):
			return False

	return intersec >= (len(x) - intersec) and intersec >= (len(y) - intersec)

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Get matches between our name and their names

def _ScoreWordLists(our_words, our_joined, our_is_short, their_words,
		prefix_penalty, mark_penalty, aux):
	'''
	Score one word list of their name against the word list of our name.

	Returns the score rounded to 3 digits, or None when the pair is skipped
	(word counts too different, too many words, a name too short) or when
	the fuzzy score falls below aux.MIN_SCORE.
	'''

	n_rows = len(our_words)
	n_cols = len(their_words)

	# No point in mapping the names if their numbers of words differ by more
	# than 2 or either one has more words than aux.ILIST_ROWS / ILIST_COLS
	if abs(n_rows - n_cols) > 2 or \
			n_rows > aux.ILIST_ROWS or n_cols > aux.ILIST_COLS:
		return None

	their_joined = ''.join(their_words)

	# Names differing only in the numerical part at the right get the full
	# score (only the penalties apply, there is no MIN_SCORE check here)
	if SameAlphaPart(our_joined, their_joined):
		return round(1.0 * prefix_penalty * mark_penalty, 3)

	n_their_chars = len(their_joined)

	# The fuzzy comparison is skipped when either name is too short
	if our_is_short or n_their_chars < aux.MIN_ACCEPTABLE_LEN:
		return None

	# criterions[i][j] = (edit distance, length of the longer word) for the
	# i-th word of our name and the j-th word of their name
	criterions = [
		[
			(Distance(our_word, their_word),
				max(len(our_word), len(their_word)))
			for their_word in their_words
		]
		for our_word in our_words
	]

	# Best total over the word pairings supplied by aux; each pair (i, j)
	# contributes (1 - dist / length) * (length / longest), written below in
	# the equivalent form (length - dist) * (1 / longest)
	longest = max(len(our_joined), n_their_chars)
	best = 0
	for pairs in aux.GetListOfListsOfIndexes(n_rows, n_cols):
		total = sum(
			(length - dist) * (1.0 / longest)
			for dist, length in (criterions[i][j] for i, j in pairs)
			)
		if total > best:
			best = total

	# Cap at 1.0, then impose the prefix and mark penalties
	score = min(best, 1.0) * prefix_penalty * mark_penalty
	if score >= aux.MIN_SCORE:
		return round(score, 3)
	return None

def GetMatchesForName(our_name, their_names, aux):
	'''
	Score our package name against every name in their_names.

	our_name and the values of their_names are indexed by the NAME_* keys of
	cpe_map.name_conv. aux provides PREFIX_PENALTY, MIN_ACCEPTABLE_LEN,
	MIN_SCORE, ILIST_ROWS, ILIST_COLS and GetListOfListsOfIndexes(rows, cols).

	Returns a dict {key of their_names: best score rounded to 3 digits}.
	Names with a group incompatible with ours, or without an acceptable
	score, are omitted.
	'''

	# The WO_MARK variant of our name is preferred when it is non-empty; any
	# match found with it is penalized by aux.PREFIX_PENALTY
	our_words = our_name[NAME_FORMATTED_SPLITTED_WO_MARK]
	if our_words:
		prefix_penalty = aux.PREFIX_PENALTY
	else:
		our_words = our_name[NAME_FORMATTED_SPLITTED]
		prefix_penalty = 1

	our_joined = ''.join(our_words)
	our_is_short = len(our_joined) < aux.MIN_ACCEPTABLE_LEN

	matches = {}
	for their_key, their_name in their_names.items():
		if not CompatibleGroups(our_name[NAME_GROUP], their_name[NAME_GROUP]):
			continue
		# Their name is tried as NAME_FORMATTED_SPLITTED and as the WO_MARK
		# variant, the latter with a slight (0.99) penalty
		variants = (
			(their_name[NAME_FORMATTED_SPLITTED], 1),
			(their_name[NAME_FORMATTED_SPLITTED_WO_MARK], 0.99),
			)
		scores = []
		for their_words, mark_penalty in variants:
			if not their_words:
				continue
			score = _ScoreWordLists(our_words, our_joined, our_is_short,
				their_words, prefix_penalty, mark_penalty, aux)
			if score is not None:
				scores.append(score)
		if scores:
			matches[their_key] = max(scores)

	return matches

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Checking for partial name match

def GetMatches(our_names, extra_params, send_pipe, _):
	'''
	Worker passed to Parallel: match every name of our_names against
	extra_params.their_names (with extra_params.aux) and send the dict
	{key of our_names: MatchesStr(matches)} through send_pipe. Names without
	any match are left out. The last parameter is not used here.
	'''

	found_per_name = (
		(key, GetMatchesForName(
			name, extra_params.their_names, extra_params.aux))
		for key, name in our_names.items()
		)
	send_pipe.send({
		key: MatchesStr(found)
		for key, found in found_per_name
		if found
		})

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Forming dict of partial package name matches

if __name__ == '__main__':

	# sys.exit instead of the site-provided exit(), which is meant for the
	# interactive shell and is absent when site is not loaded (python -S)
	import sys

	# Parsing the args and getting the helper objects
	args = argparser.parse_args()
	type_of_matching = M_BIN_PARTIAL if args.bin else M_PARTIAL
	conf, speaker, mediator = Init(type_of_matching, args)

	# Checking cve-manager home dir
	common_params, = conf.Get([COMMON_SEC])
	home_dir, err = PrepareDir(common_params.get('download'))
	if not home_dir:
		speaker.Status(err=err)
		sys.exit(1)

	# Getting package names
	mx_of_our_names = mediator.QueryPackages(args.packages, args.names,
		cpu_count())
	TerminateIfAbsent(mx_of_our_names)

	# Max number of words in our names (passed to aux.Prepare below)
	n_words_in_our_names_max = max(
		(
			len(our_name[NAME_FORMATTED_SPLITTED])
			for our_names in mx_of_our_names
			for our_name in our_names.values()
		),
		default=0
		)

	# For every type of selected data source (nvd, fstec)
	for data_source in args.data_sources:

		# Getting product names
		their_names = mediator.QueryProducts(data_source)
		if TerminateOrSkip(their_names):
			continue

		# Max number of words in their names (passed to aux.Prepare below)
		n_words_in_their_names_max = max(
			(
				len(their_name[NAME_FORMATTED_SPLITTED])
				for their_name in their_names.values()
			),
			default=0
			)

		# Preparing supplementary data
		speaker.Say('Preparing auxiliary data')
		aux = prepare_partial.Worker(home_dir)
		success = aux.Prepare(
			n_words_in_our_names_max,
			n_words_in_their_names_max,
			args.prepare
			)
		dim = f'{aux.ILIST_ROWS}x{aux.ILIST_COLS}'
		if not success:
			speaker.Status(err=dim)
			sys.exit(1)
		msg = f'Got aux data for pairs that consist of {dim} words or less'
		speaker.Status(success=msg)

		# Gathering extra params for GetMatches func
		extra_params = SimpleNamespace(their_names=their_names, aux=aux)

		# Running multiple processes of matching
		speaker.Op(data_source, mx_of_our_names, their_names)
		matches, ok = Parallel(GetMatches, mx_of_our_names, extra_params)
		if not ok:
			speaker.Status(err='Some process has terminated with an error')
			sys.exit(1)
		speaker.Status(matches=matches)

		# Updating the table with the results
		if not mediator.SendMatches(data_source, matches):
			sys.exit(1)

	sys.exit(0)
