#!/usr/bin/env python3

import argparse
from collections import namedtuple
import ctypes
from io import BytesIO
import json
from math import floor
import logging
from pathlib import Path
import struct
import sys
import tarfile
from tarfile import BLOCKSIZE
from time import time
from typing import List, Tuple, Optional

# struct layout of one index.bin entry; write_index_to_tar() packs each entry
# as (tile data offset in the tar, numeric tile id, tile size in bytes)
# "<" prefix means little-endian and no alignment
# order is important! if uint64_t is not first, c++ will use padding bytes to unpack
INDEX_BIN_FORMAT = '<QLL'
# size in bytes of a single packed index entry
INDEX_BIN_SIZE = struct.calcsize(INDEX_BIN_FORMAT)
# name of the index member, which is always added first to the tar files
INDEX_FILE = "index.bin"
# skip the first 40 bytes of the tile header
GRAPHTILE_SKIP_BYTES = struct.calcsize('<Q2f16cQ')
# traffic tile header; create_extracts() packs it as
# (tile id, timestamp, directed edge count, traffic version, spare2, spare3)
TRAFFIC_HEADER_FORMAT = '<2Q4I'
TRAFFIC_HEADER_SIZE = struct.calcsize(TRAFFIC_HEADER_FORMAT)
# number of zeroed bytes reserved per directed edge in a traffic tile
TRAFFIC_SPEED_SIZE = struct.calcsize('<Q')
# written into the version field of every traffic tile header
TRAFFIC_VERSION = 3

Bbox = namedtuple("Bbox", "min_x min_y max_x max_y")
# hierarchy level -> tile edge length; get_tile_bbox() uses it in degrees
TILE_SIZES = {0: 4, 1: 1, 2: 0.25, 3: 0.25}

# hack so ArgumentParser can accept negative numbers
# see https://github.com/valhalla/valhalla/issues/3426
# every argument looking like "-<digit>..." gets a leading space prepended in place
for i, arg in enumerate(sys.argv):
    # the [1:2] slice is empty for a lone "-", so no separate length check is needed
    if arg.startswith('-') and arg[1:2].isdigit():
        sys.argv[i] = f" {arg}"


class TileHeader(ctypes.Structure):
    """
    Mirrors the uint64_t bit field found at bytes 40 - 48 of the
    graphtileheader; it is only read to obtain directededgecount_.
    """

    # (field name, bit width) pairs in declaration order; all fields share
    # one unsigned 64 bit storage unit
    _fields_ = [
        (field_name, ctypes.c_ulonglong, bit_width)
        for field_name, bit_width in (
            ("nodecount_", 21),
            ("directededgecount_", 21),
            ("predictedspeeds_count_", 21),
            ("spare1_", 1),
        )
    ]


class TileResolver:
    def __init__(self, path: Path):
        """
        Abstraction so we don't have to care whether we're looking at a tile directory or tar file.

        :param path: path to the tile directory or tar file.
        """
        # define the handle first so __del__ stays safe if anything below raises
        self._tar_obj: Optional[tarfile.TarFile] = None
        # normalized tile path -> tar member; only filled when reading from a tar file
        self._tar_members = dict()

        self.path = path.resolve()
        self._is_tar = path.is_file()
        if self._is_tar:
            self._tar_obj = tarfile.open(self.path, "r")

        self.normalized_tile_paths: List[Path] = list()
        self.matched_paths: List[Path] = list()

        # pre-populate the available paths
        if self._is_tar:
            # remember the member per tile so add_to_tar() doesn't need a
            # linear TarFile.getmember() lookup for every single tile
            self._tar_members = {
                Path(m.name): m for m in self._tar_obj.getmembers() if m.name.endswith('.gph')
            }
            self.normalized_tile_paths = sorted(self._tar_members)
        else:
            self.normalized_tile_paths = sorted(
                p.relative_to(self.path) for p in self.path.rglob('*.gph')
            )

    def __del__(self):
        # close the tar object on GC; the attribute may be missing if
        # __init__ never got to run or failed early
        tar_obj = getattr(self, "_tar_obj", None)
        if tar_obj:
            tar_obj.close()

    def add_to_tar(self, tar: tarfile.TarFile):
        """
        Adds the self.matched_paths to the passed tar file.

        :param tar: tar file opened for writing, receives one member per unique matched tile.
        """
        # deduplicate the list (geojson variant might've added dups)
        # since 3.7 python dicts are insertion-ordered, so order is preserved
        for t in dict.fromkeys(self.matched_paths):
            LOGGER.debug(f"Adding tile {t} to the tar file")
            if self._is_tar:
                tar_member = self._tar_members.get(t)
                if tar_member is None:
                    # path wasn't discovered by __init__, fall back to a lookup by name
                    tar_member = self._tar_obj.getmember(str(t))
                # hand over the member itself to skip a second lookup by name
                tar.addfile(tar_member, self._tar_obj.extractfile(tar_member))
            else:
                # tar member names always use forward slashes
                tar.add(str(self.path.joinpath(t)), arcname=t.as_posix())


description = "Builds a tar extract from the tiles in mjolnir.tile_dir to the path specified in mjolnir.tile_extract."

# command line interface; the arguments are registered in the order they show up in --help
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
    "-c",
    "--config",
    type=Path,
    help="Absolute or relative path to the Valhalla config JSON.",
)
parser.add_argument(
    "-i",
    "--inline-config",
    type=str,
    default='{}',
    help="Inline JSON config, will override --config JSON if present",
)
parser.add_argument(
    "-e",
    "--extract-tar",
    type=str,
    default="",
    help="If specified, will build an extract from an existing tar file at 'mjolnir.tile_extract' and save it to this specified path.",
)
parser.add_argument(
    "-O",
    "--overwrite",
    action="store_true",
    help="Overwrites an output tar file",
)
parser.add_argument(
    "-t",
    "--with-traffic",
    action="store_true",
    help="Flag to add a traffic.tar skeleton",
)

# the spatial filters can't be combined, at most one of them may be passed
geom_type = parser.add_mutually_exclusive_group()
geom_type.add_argument(
    "-b",
    "--bbox",
    type=str,
    default="",
    help="If specified, will only archive tiles which are intersecting this bbox. In 'minx,miny,maxx,maxy' format.",
)
geom_type.add_argument(
    "-g",
    "--geojson-dir",
    type=Path,
    help="Absolute or relative path to directory with .geojson files used "
    "as input to tile intersection. Requires shapely.",
)
parser.add_argument(
    "-v",
    "--verbosity",
    action='count',
    default=0,
    help="Accumulative verbosity flags; -v: INFO, -vv: DEBUG",
)

# module logger with a timestamped stream handler; the level is set later from -v
LOGGER = logging.getLogger(__name__)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(fmt="%(asctime)s %(levelname)5s: %(message)s"))
LOGGER.addHandler(handler)


def get_tile_bbox(tile_path_id: str) -> Tuple[float, float, float, float]:
    """
    Returns a tile's bounding box in minx,miny,maxx,maxy format

    Logs a critical message and exits with code 1 if the level or tile index
    can't be parsed from the path or the level has no entry in TILE_SIZES.

    :param tile_path_id: relative tile path, e.g. "2/000/818/660.gph".
    """
    try:
        level, tile_idx = [int(x.replace("/", "")) for x in get_tile_level_id(tile_path_id)]
        # an unknown level raises KeyError, treat it like any other unparsable path
        tile_size = TILE_SIZES[level]
    except (ValueError, KeyError):
        LOGGER.critical(f"Couldn't get level and tile ID for tile {tile_path_id}")
        sys.exit(1)

    # tiles are counted row by row, starting at the lower left corner (-180, -90)
    tiles_per_row = 360 / tile_size
    row = floor(tile_idx / tiles_per_row)
    col = tile_idx % tiles_per_row

    tile_base_y = (row * tile_size) - 90
    tile_base_x = (col * tile_size) - 180

    return tile_base_x, tile_base_y, tile_base_x + tile_size, tile_base_y + tile_size


def get_tiles_with_geojson(tile_resolver_: TileResolver, geojson_dir: Path):
    """
    Collects all tile paths intersecting with the GeoJSON (multi)polygons

    Every matching tile is appended exactly once to tile_resolver_.matched_paths.
    Logs a critical message and exits with code 1 if shapely can't be imported
    or geojson_dir holds no *.geojson files.

    :param tile_resolver_: resolver providing the candidate tiles and receiving the matches.
    :param geojson_dir: directory containing the *.geojson files.
    """
    try:
        from shapely.geometry import Polygon, box
    except ImportError:
        LOGGER.critical(
            "Could not import shapely. Please install shapely or use another download method instead."
        )
        sys.exit(1)

    # any() stops at the first hit instead of listing the whole directory
    if not geojson_dir.is_dir() or not any(geojson_dir.glob('*.geojson')):
        LOGGER.critical(
            f"Geojson directory does not exist or contains no GeoJSON files: {geojson_dir.resolve()}"
        )
        sys.exit(1)

    def get_outer_rings(input_dir: Path) -> List[Polygon]:
        """Returns the outer ring of every (Multi)Polygon feature as a shapely Polygon"""
        polygons = list()
        for file in input_dir.glob("*.geojson"):
            with open(file) as gj_f:
                geojson = json.load(gj_f)
            for feature in geojson["features"]:
                # GeoJSON allows features with a null geometry, skip those
                geometry = feature.get("geometry") or {}
                geometry_type = geometry.get("type")
                if geometry_type == "Polygon":
                    polygons.append(Polygon(geometry["coordinates"][0]))
                elif geometry_type == "MultiPolygon":
                    for single_polygon in geometry["coordinates"]:
                        polygons.append(Polygon(single_polygon[0]))
        return polygons

    geojson_polys = get_outer_rings(geojson_dir)
    for tile_path in tile_resolver_.normalized_tile_paths:
        tile_bbox = box(*get_tile_bbox(str(tile_path)))
        for poly in geojson_polys:
            if poly.intersects(tile_bbox):
                tile_resolver_.matched_paths.append(tile_path)
                # one intersecting polygon is enough, don't add the tile again
                break


def get_tiles_with_bbox(tile_resolver_: TileResolver, bbox_str: str):
    """
    Collects all tile paths intersecting with the bbox

    Matching tiles are appended to tile_resolver_.matched_paths; tiles which
    merely touch the bbox count as intersecting. Logs a critical message and
    exits with code 1 if the bbox can't be parsed or is invalid.

    :param tile_resolver_: resolver providing the candidate tiles and receiving the matches.
    :param bbox_str: bbox in 'minx,miny,maxx,maxy' format.
    """
    try:
        bbox = Bbox(*[float(x) for x in bbox_str.split(",")])
    except (ValueError, TypeError):
        # ValueError: a value is not a number; TypeError: not exactly 4 values
        LOGGER.critical(f"BBOX {bbox_str} is not a comma-separated string of coordinates.")
        sys.exit(1)

    # validate bbox
    if (bbox.min_x >= bbox.max_x or bbox.min_x < -180 or bbox.max_x > 180) or (
        bbox.min_y >= bbox.max_y or bbox.min_y < -90 or bbox.max_y > 90
    ):
        LOGGER.critical(f"Bbox invalid: {list(bbox)}")
        sys.exit(1)

    for tile_path in tile_resolver_.normalized_tile_paths:
        tile_bbox = Bbox(*get_tile_bbox(str(tile_path)))
        # a tile's min is always smaller than its max, so comparing a single
        # edge per side is enough to tell whether the tile is outside the bbox
        is_outside = (
            tile_bbox.max_x < bbox.min_x  # left of bbox
            or tile_bbox.max_y < bbox.min_y  # below bbox
            or tile_bbox.min_x > bbox.max_x  # right of bbox
            or tile_bbox.min_y > bbox.max_y  # above bbox
        )
        if not is_outside:
            tile_resolver_.matched_paths.append(tile_path)


def get_tile_level_id(path: str) -> List[str]:
    """
    Returns both level and tile ID

    :param path: relative tile path ending in ".gph", e.g. "2/000/818/660.gph".
    :returns: [level, tile ID] where the tile ID still contains its "/" separators,
        or a single-element list if the path has no separator at all.
    """
    # paths stringified from pathlib on Windows use backslashes, normalize them first
    posix_path = path.replace('\\', '/')
    # cut off the 4 character ".gph" suffix, then split the level from the rest
    return posix_path[:-4].split('/', 1)


def get_tile_id(path: str) -> int:
    """Converts a tile path to its numeric GraphId, with the level in the lowest 3 bits"""
    level_part, index_part = get_tile_level_id(path)

    level_bits = int(level_part)
    # the tile index is spread over several path segments, glue them back together
    tile_index = int("".join(index_part.split('/')))

    return (tile_index << 3) | level_bits


def get_tar_info(name: str, size: int) -> tarfile.TarInfo:
    """Builds the TarInfo of a regular file with the given name and size, stamped with the current time"""
    member_info = tarfile.TarInfo(name=name)
    member_info.type = tarfile.REGTYPE
    member_info.mtime = int(time())
    member_info.size = size

    return member_info


def write_index_to_tar(tar_fp_: Path):
    """Collects offset, tile ID and size of every tile in the tar and writes them into its index.bin"""
    entry_struct = struct.Struct(INDEX_BIN_FORMAT)

    # first pass: read the offset and size of each tarred tile member
    entries: List[Tuple[int, int, int]] = []
    with tarfile.open(tar_fp_, 'r|') as tar_in:
        for member in tar_in.getmembers():
            if not member.name.endswith('.gph'):
                continue
            LOGGER.debug(
                f"Tile {member.name} with offset: {member.offset_data}, size: {member.size}"
            )
            entries.append((member.offset_data, get_tile_id(member.name), member.size))

    # second pass: overwrite the index placeholder in place
    with open(tar_fp_, 'r+b') as raw_tar:
        # index.bin is the first member, so its data starts right after its header block
        raw_tar.seek(BLOCKSIZE)
        raw_tar.writelines(entry_struct.pack(*fields) for fields in entries)


def create_extracts(config_: dict, do_traffic: bool, tile_resolver_: TileResolver, extract_fp: Path):
    """
    Actually creates the tar ball. Break out of main function for testability.

    Writes the matched tiles plus a leading index.bin to extract_fp. If
    do_traffic is set, also writes a traffic tar holding one zeroed,
    fixed-size traffic tile per routing tile.

    Exits with code 1 if no tiles were matched and with code 0 after the
    tile extract was written when no traffic extract is wanted.

    :param config_: Valhalla config; only mjolnir.traffic_extract is read here.
    :param do_traffic: whether to create the traffic extract as well.
    :param tile_resolver_: resolver whose matched_paths are archived.
    :param extract_fp: output path of the tile extract.
    """
    # count unique tiles only: add_to_tar() drops duplicates, so counting them
    # would reserve index entries which never get filled
    tiles_count = len(set(tile_resolver_.matched_paths))
    if not tiles_count:
        LOGGER.critical(f"Couldn't find usable tiles in {tile_resolver_.path}")
        sys.exit(1)

    # zero-filled in-memory placeholder, the real entries are written once
    # the tile offsets inside the tar are known
    index_size = INDEX_BIN_SIZE * tiles_count
    index_fd = BytesIO(b'\0' * index_size)

    # first add the index file, then the sorted tiles to the tarfile
    # TODO: come up with a smarter strategy to cluster the tiles in the tar
    extract_fp.parent.mkdir(parents=True, exist_ok=True)
    with tarfile.open(extract_fp, 'w') as tar:
        tar.addfile(get_tar_info(INDEX_FILE, index_size), index_fd)
        tile_resolver_.add_to_tar(tar)

    write_index_to_tar(extract_fp)

    LOGGER.info(f"Finished tarring {tiles_count} tiles to {extract_fp}")

    # exit if no traffic extract wanted
    if not do_traffic:
        index_fd.close()
        sys.exit(0)

    LOGGER.info("Start creating traffic extract...")
    traffic_fp: Path = Path(
        config_["mjolnir"].get("traffic_extract") or extract_fp.parent.joinpath("traffic.tar")
    )
    # the traffic extract may live in another directory than the tile extract
    traffic_fp.parent.mkdir(parents=True, exist_ok=True)

    # we already have the right size of the index file, simply reset it
    index_fd.seek(0)
    with tarfile.open(extract_fp) as tar_in, tarfile.open(traffic_fp, 'w') as tar_traffic:
        # this will let us do seeks
        in_fileobj = tar_in.fileobj

        # add the index file as first data
        tar_traffic.addfile(get_tar_info(INDEX_FILE, index_size), index_fd)
        index_fd.close()

        # loop over all routing tiles and create fixed-size traffic tiles
        # based on the directed edge count
        for tile_in in tar_in.getmembers():
            if not tile_in.name.endswith('.gph'):
                continue
            # jump to the data's offset and skip the uninteresting bytes
            in_fileobj.seek(tile_in.offset_data + GRAPHTILE_SKIP_BYTES)

            # read the appropriate size of bytes from the tar into the TileHeader struct
            tile_header = TileHeader()
            b = BytesIO(in_fileobj.read(ctypes.sizeof(TileHeader)))
            b.readinto(tile_header)
            b.close()
            edge_count = tile_header.directededgecount_

            # create the traffic header
            header_bytes = struct.pack(
                TRAFFIC_HEADER_FORMAT,
                get_tile_id(tile_in.name),  # numerical tile_id
                0,  # timestamp
                edge_count,  # edge count
                TRAFFIC_VERSION,  # tile version
                0,  # spare2
                0,  # spare3
            )

            # create the traffic tile: the header followed by one zeroed speed record per edge
            traffic_size = TRAFFIC_SPEED_SIZE * edge_count
            tar_traffic.addfile(
                get_tar_info(tile_in.name, len(header_bytes) + traffic_size),
                BytesIO(header_bytes + b'\0' * traffic_size),
            )

            LOGGER.debug(
                f"Traffic tile {tile_in.name} has {tile_header.directededgecount_} directed edges"
            )

    write_index_to_tar(traffic_fp)

    LOGGER.info(f"Finished creating the traffic extract at {traffic_fp}")


if __name__ == '__main__':
    # script entry point: parse the arguments, load the config, pick the tile
    # source, select the tiles and finally build the extract(s)
    args = parser.parse_args()

    # set the right logger level
    if args.verbosity == 0:
        LOGGER.setLevel(logging.CRITICAL)
    elif args.verbosity == 1:
        LOGGER.setLevel(logging.INFO)
    else:
        LOGGER.setLevel(logging.DEBUG)

    if not args.config and not args.inline_config:
        LOGGER.critical("No valid config file or inline config used.")
        sys.exit(1)

    config = dict()
    if args.config:
        try:
            with open(args.config) as f:
                config = json.load(f)
        except (OSError, ValueError) as e:
            # missing/unreadable file or invalid JSON (JSONDecodeError is a ValueError)
            LOGGER.critical(f"Couldn't read config file {args.config}: {e}")
            sys.exit(1)
    else:
        LOGGER.warning("Only inline-config will be used.")

    # override with inline-config
    try:
        config.update(**json.loads(args.inline_config or '{}'))
    except (TypeError, ValueError) as e:
        LOGGER.critical(f"Inline config is not a valid JSON object: {e}")
        sys.exit(1)

    # everything below reads from the 'mjolnir' section, so make sure it's there
    mjolnir_config = config.get("mjolnir")
    if not isinstance(mjolnir_config, dict):
        LOGGER.critical("No 'mjolnir' section found in the config.")
        sys.exit(1)

    # if output file exists and not --overwrite specified, fail
    tiles_extract_out = args.extract_tar or mjolnir_config.get("tile_extract")
    if not tiles_extract_out:
        LOGGER.critical("No output file path specified in 'mjolnir.tile_extract' or --extract-tar.")
        sys.exit(1)

    tiles_extract_out = Path(tiles_extract_out)
    if tiles_extract_out.is_file() and not args.overwrite:
        LOGGER.critical(f"File exists. Specify --overwrite to overwrite {tiles_extract_out}")
        sys.exit(1)

    # prefer the tar file to extract from, fall back to the tile dir
    # missing or empty settings must neither crash Path() nor resolve to the cwd
    tiles_dir: Path = Path(mjolnir_config.get("tile_dir") or '/dev/null')
    tiles_extract_in = mjolnir_config.get("tile_extract")
    # optionally use the tile extract to find the tiles
    if args.extract_tar and tiles_extract_in and Path(tiles_extract_in).is_file():
        tile_resolver = TileResolver(Path(tiles_extract_in))
        LOGGER.debug("Using tar file to extract tiles")
    # else get and validate the tiles directory
    elif tiles_dir.is_dir():
        tile_resolver = TileResolver(tiles_dir)
        LOGGER.debug("Using graph dir to extract tiles")
    else:
        LOGGER.critical(
            f"Can't find valid paths for 'mjolnir.tile_dir' or 'mjolnir.tile_extract' in {args.config}"
        )
        sys.exit(1)

    # get the tile paths which intersect with the geom, if any, and write to TileResolver
    if args.bbox:
        get_tiles_with_bbox(tile_resolver, args.bbox)
    elif args.geojson_dir:
        get_tiles_with_geojson(tile_resolver, args.geojson_dir)
    else:
        tile_resolver.matched_paths = tile_resolver.normalized_tile_paths

    create_extracts(config, args.with_traffic, tile_resolver, tiles_extract_out)
