#!/usr/bin/env python3

# Copyright (C) 2025 Volkov Alexey
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License,
# or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
# or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import argparse
import logging
import logging.handlers
import sys

from alt_workstation_10_11_upgrade.upgrade import run_upgrade
from alt_workstation_10_11_upgrade.switch import run_switch

def parse_args(argv=None):
	"""Parse the command-line options of the upgrade tool.

	Options:
		-y/--yes      assume "yes" for every prompt;
		-u/--upgrade  only run the upgrade stage;
		-s/--switch   only run the switch stage.

	-u and -s are mutually exclusive; when neither is given, upgrade and
	switch are performed sequentially.

	Args:
		argv: argument list to parse. None (the default) parses
			sys.argv[1:], exactly like the previous zero-argument form.

	Returns:
		argparse.Namespace with boolean attributes ``yes``, ``upgrade`` and
		``switch``.
	"""
	parser = argparse.ArgumentParser(
		description="Performs upgrade and switch sequentially\n"
		            "if -u and -s are not specified.",
		# Keep the explicit line break in the description: the default
		# HelpFormatter re-wraps description text and would discard it.
		formatter_class=argparse.RawDescriptionHelpFormatter,
	)
	parser.add_argument(
		"-y", "--yes",
		action="store_true",
		help="Assume Yes to all queries and do not prompt"
	)
	# Only one stage may be selected explicitly; selecting none runs both.
	group = parser.add_mutually_exclusive_group(required=False)
	group.add_argument(
		"-u", "--upgrade",
		action="store_true",
		help="Only upgrade the system to P11"
	)
	group.add_argument(
		"-s", "--switch",
		action="store_true",
		help="Only apply switch to upgraded system"
	)
	return parser.parse_args(argv)

class LevelSpecificFormatter(logging.Formatter):
	"""Formatter that chooses a format string by the record's level.

	``fmt_dict`` maps numeric logging levels (e.g. ``logging.INFO``) to format
	strings. The optional ``'default'`` key supplies the format for any level
	that has no entry of its own. If that key is absent, a
	timestamp/name/level/message format is used.
	"""

	_FALLBACK_FMT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

	def __init__(self, fmt_dict: dict, datefmt=None):
		default_fmt = fmt_dict.get('default', self._FALLBACK_FMT)
		# Initialise the base Formatter (sets _fmt, _style, datefmt) so that
		# inherited helpers such as usesTime()/formatTime() work on this instance.
		super().__init__(default_fmt, datefmt)
		self.fmt_dict = fmt_dict
		# One delegate formatter per configured key.
		self.formatters = {
			level: logging.Formatter(fmt, datefmt)
			for level, fmt in fmt_dict.items()
		}
		self.default_formatter = logging.Formatter(default_fmt, datefmt)

	def format(self, record):
		"""Format *record* with the formatter registered for its level number."""
		formatter = self.formatters.get(record.levelno, self.default_formatter)
		return formatter.format(record)

def setup_logging(log_path='/var/log/alt-workstation-upgrade.log'):
	"""Attach a rotating log file and a console handler to the root logger.

	Every record from DEBUG upward is written to *log_path* using a
	LevelSpecificFormatter: INFO records carry only the message, other levels
	carry a timestamp, logger name and level, and ERROR/CRITICAL also carry
	the source file and line. The file rotates at 5 MiB and keeps three
	backups. INFO and above are additionally echoed as bare messages to the
	console stream (stderr, the StreamHandler default).

	Args:
		log_path: file receiving the full debug log. The default lives under
			/var/log, which typically requires root privileges to open.
	"""
	root_logger = logging.getLogger()
	root_logger.setLevel(logging.DEBUG)

	detailed_fmt = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
	# Error-level records also say where in the code they were emitted.
	located_fmt = detailed_fmt + '\nFile: %(filename)s, line %(lineno)d'
	fmt_dict = {
		logging.INFO: '%(message)s',
		logging.DEBUG: detailed_fmt,
		logging.WARNING: detailed_fmt,
		logging.ERROR: located_fmt,
		logging.CRITICAL: located_fmt,
		'default': detailed_fmt,
	}

	formatter = LevelSpecificFormatter(fmt_dict, datefmt='%Y-%m-%d %H:%M:%S')

	file_handler = logging.handlers.RotatingFileHandler(
		log_path, maxBytes=5_242_880, backupCount=3, encoding='utf-8'
	)
	file_handler.setLevel(logging.DEBUG)
	file_handler.setFormatter(formatter)

	console_handler = logging.StreamHandler()
	console_handler.setLevel(logging.INFO)
	console_handler.setFormatter(logging.Formatter('%(message)s'))

	root_logger.addHandler(file_handler)
	root_logger.addHandler(console_handler)

def print_contact_the_support(stage: str):
	"""Log a CRITICAL message saying that *stage* failed and support is needed.

	Args:
		stage: human-readable name of the failed stage, e.g. "Upgrade to P11".
	"""
	# %-style arguments let the logging framework do the string formatting.
	logging.getLogger(__name__).critical(
		"%s failed, contact support to resolve your problem.", stage
	)

def upgrade(only_upgrade: bool, assume_yes=False) -> bool:
	"""Run the upgrade-to-P11 stage and terminate the process if it fails.

	Args:
		only_upgrade: forwarded to run_upgrade (main() passes True when only
			the upgrade stage was requested with -u).
		assume_yes: forwarded to run_upgrade; mirrors -y/--yes.

	Returns:
		The truthy result of run_upgrade. A falsy result never returns: the
		failure is logged as CRITICAL and the process exits with status 1.
	"""
	success = run_upgrade(only_upgrade, assume_yes)
	if not success:
		print_contact_the_support("Upgrade to P11")
		# sys.exit is always available, unlike the site-injected exit() builtin.
		sys.exit(1)

	return success

def switch(is_upgraded, assume_yes=False) -> bool:
	"""Run the MATE-to-GNOME switch stage and terminate the process if it fails.

	Args:
		is_upgraded: forwarded to run_switch. main() passes False for a
			switch-only run (-s) and the upgrade result otherwise.
		assume_yes: forwarded to run_switch; mirrors -y/--yes.

	Returns:
		The truthy result of run_switch. A falsy result never returns: the
		failure is logged as CRITICAL and the process exits with status 1.
	"""
	success = run_switch(is_upgraded, assume_yes)
	if not success:
		print_contact_the_support("Switch from Mate to GNOME")
		# sys.exit is always available, unlike the site-injected exit() builtin.
		sys.exit(1)

	return success

def main():
	"""Entry point: parse the CLI, configure logging and run the chosen stages.

	With -u only the upgrade runs and with -s only the switch runs. With
	neither flag, the upgrade runs first and the switch follows. -y/--yes is
	forwarded to every stage.
	"""
	# Parse arguments before setup_logging() opens the log file, so that
	# --help and usage errors work even without write access to the log.
	args = parse_args()
	setup_logging()

	logger = logging.getLogger(__name__)
	assume_yes = args.yes

	if args.upgrade:
		logger.info("Only upgrading to P11")
		upgrade(True, assume_yes=assume_yes)
	elif args.switch:
		logger.info("Only switching from MATE to GNOME")
		switch(False, assume_yes=assume_yes)
	else:
		# upgrade() exits the process on failure, so reaching switch() means
		# the upgrade succeeded and its truthy result marks the system upgraded.
		switch(upgrade(False, assume_yes), assume_yes)

# Run the command-line tool only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
	main()
