#!/usr/bin/env python3
# Copyright (c) 2015-2016  Milos Jakubicek
from __future__ import print_function
from __future__ import unicode_literals

import sys, array
from os.path import basename
from os.path import join as pathjoin
from struct import pack

# Python 2 compatibility: bind xrange to the name range when it exists; on
# Python 3 the lookup raises NameError and the built-in range stays in effect.
try:
    range = xrange
except NameError:
    pass

# Binary output goes to the byte stream behind sys.stdout when it exposes one
# (the .buffer attribute); otherwise sys.stdout itself is used.
stdout = getattr(sys.stdout, "buffer", sys.stdout)

from manatee import Corpus, new_TokenLevel, full_level, MLTStream

def process_MLTStream (al1_struc, al2_struc, map_file):
    """Yield aligned position ranges described by an alignment map file.

    The map is opened with full_level (new_TokenLevel (map_file)) and walked
    change by change until the stream reports its end.

    al1_struc, al2_struc -- alignment structures of the two corpora; only
                            their beg (n) and end (n) methods are used
    map_file             -- path handed to new_TokenLevel

    Yields tuples (al1_begpos, al1_endpos, al2_begpos, al2_endpos):
      * KEEP changes yield one tuple per structure, pairing structure
        orgpos()+k with structure newpos()+k for k in 0..change_size()-1;
      * MORPH changes yield a single tuple spanning change_size() structures
        on the first side and change_newsize() structures on the second;
      * any other change type yields nothing.

    Iteration also stops, without raising, when mlts.next() fails.
    """
    mlts = full_level (new_TokenLevel (map_file))
    while not mlts.end():
        al1_num = mlts.orgpos()
        al2_num = mlts.newpos()
        if mlts.change_type() == MLTStream.KEEP:
            # For KEEP, change_size equals change_newsize, so a single counter
            # advances both sides in lockstep.
            for _ in range(mlts.change_size()):
                al1_begpos = al1_struc.beg (al1_num)
                al1_endpos = al1_struc.end (al1_num)
                al2_begpos = al2_struc.beg (al2_num)
                al2_endpos = al2_struc.end (al2_num)
                yield (al1_begpos, al1_endpos, al2_begpos, al2_endpos)
                al1_num += 1
                al2_num += 1
        elif mlts.change_type() == MLTStream.MORPH:
            # m:n group: from the start of the first structure to the end of
            # the last structure on each side.
            al1_begpos = al1_struc.beg (al1_num)
            al1_endpos = al1_struc.end (al1_num + mlts.change_size() - 1)
            al2_begpos = al2_struc.beg (al2_num)
            al2_endpos = al2_struc.end (al2_num + mlts.change_newsize() - 1)
            yield (al1_begpos, al1_endpos, al2_begpos, al2_endpos)
        try:
            mlts.next()
        except Exception:
            # Best effort: a failure to advance is treated as end of stream.
            # Only Exception is caught so that KeyboardInterrupt / SystemExit
            # still propagate instead of silently truncating the output.
            break

def process_OneToOne (al1_struc, al2_struc):
    """Yield position ranges of structures paired one-to-one by index.

    Structure number n of al1_struc is paired with structure number n of
    al2_struc; pairing stops at the smaller of the two size() values.

    Yields tuples (al1_begpos, al1_endpos, al2_begpos, al2_endpos) taken from
    the beg (n) and end (n) methods of each structure.
    """
    paired_count = min(al1_struc.size(), al2_struc.size())
    for struct_num in range(paired_count):
        yield (al1_struc.beg (struct_num), al1_struc.end (struct_num),
               al2_struc.beg (struct_num), al2_struc.end (struct_num))

def par2tokens (al1, al2, al1_attrname, al2_attrname, multi_align, map_file=None):
    """Emit aligned token ID pairs to stdout and store per-ID frequencies.

    al1, al2         -- the two corpus objects; the structure named by their
                        ALIGNSTRUCT setting delimits the aligned ranges
    al1_attrname,
    al2_attrname     -- names of the attributes whose IDs are read
    multi_align      -- when false, the IDs found in each aligned range are
                        reduced to a set, so each distinct ID is counted and
                        paired once per range; when true every occurrence is
                        counted and paired
    map_file         -- alignment map path; when given the ranges come from
                        process_MLTStream, otherwise from process_OneToOne

    For every combination of an ID from the first range with an ID from the
    second range, three values packed with struct format "I" are written to
    stdout: first ID, second ID and the constant 1.

    The frequency lists are written as array.array ("L") dumps to
      <al1 PATH><al1_attrname>.align.<al2 conffile>.frq
      <al2 PATH><al2_attrname>.align.<al1 conffile>.frq
    """
    src_struct = al1.get_struct (al1.get_conf ("ALIGNSTRUCT"))
    dst_struct = al2.get_struct (al2.get_conf ("ALIGNSTRUCT"))
    src_attr = al1.get_attr (al1_attrname)
    dst_attr = al2.get_attr (al2_attrname)
    src_freqs = [0] * src_attr.id_range()
    dst_freqs = [0] * dst_attr.id_range()

    # Pack every possible ID once up front; the loops below only index this
    # table instead of calling pack() for each emitted pair.
    table_size = max(len(src_freqs), len(dst_freqs))
    packed_ids = [pack("I", token_id) for token_id in range(table_size)]
    packed_one = pack("I", 1)
    emit = stdout.write

    aligned_ranges = (process_MLTStream(src_struct, dst_struct, map_file)
                      if map_file
                      else process_OneToOne(src_struct, dst_struct))

    for src_beg, src_end, dst_beg, dst_end in aligned_ranges:
        src_iter = src_attr.posat (src_beg)
        dst_iter = dst_attr.posat (dst_beg)
        src_ids = [src_iter.next() for _ in range (src_end - src_beg)]
        dst_ids = [dst_iter.next() for _ in range (dst_end - dst_beg)]
        if not multi_align:
            # Collapse repeated IDs within this aligned range.
            src_ids = set(src_ids)
            dst_ids = set(dst_ids)
        for src_id in src_ids:
            src_freqs [src_id] += 1
            packed_src = packed_ids[src_id]
            for dst_id in dst_ids:
                emit (packed_src)
                emit (packed_ids[dst_id])
                emit (packed_one)
        for dst_id in dst_ids:
            dst_freqs [dst_id] += 1

    src_counts = array.array ("L", src_freqs)
    dst_counts = array.array ("L", dst_freqs)
    src_path = al1.get_conf("PATH") + al1_attrname + ".align." \
               + al2.get_conffile() + ".frq"
    dst_path = al2.get_conf("PATH") + al2_attrname + ".align." \
               + al1.get_conffile() + ".frq"
    for frq_path, counts in ((src_path, src_counts), (dst_path, dst_counts)):
        with open(frq_path, "wb") as frq_file:
            counts.tofile(frq_file)

if __name__ == "__main__":
    # Consume the optional flag before counting positional arguments, so that
    # "--multi-align A B C" is rejected with the usage text instead of passing
    # the length check and failing later with IndexError.
    args = sys.argv[1:]
    multi_align = False
    if args and args[0] == "--multi-align":
        multi_align = True
        args.pop(0)
    if len(args) < 4:
        print("""Usage: %s [--multi-align] <SRC_CORPUS> <DST_CORPUS> <SRC_ATTR> <DST_ATTR>
Extracts aligned token IDs on <SRC_ATTR> and <DST_ATTR> attribute from two aligned
corpora using mapping from <SRC_CORPUS> to <DST_CORPUS>.
Prints token IDs to stdout and stores frequencies of IDs in alignments to
<SRC_CORPUS_PATH>/<SRC_ATTR>.align.<DST_CORPUS_NAME>.frq and
<DST_CORPUS_PATH>/<DST_ATTR>.align.<SRC_CORPUS_NAME>.frq
With --multi-align, repeated token IDs within an aligned segment are not
collapsed: every occurrence is counted and paired.""" \
    % sys.argv[0])
        sys.exit(1)
    src_corpname, dst_corpname, src_attrname, dst_attrname = args[:4]
    al1 = Corpus (src_corpname)
    al2 = Corpus (dst_corpname)
    if al1.get_conf('ALIGNDEF'): # m:n
        # The m:n mapping is read from "align.<DST_CORPUS basename>" inside
        # the source corpus PATH.
        map_file = pathjoin(al1.get_conf("PATH"), 'align.' + basename(dst_corpname))
        par2tokens(al1, al2, src_attrname, dst_attrname, multi_align, map_file)
    else: # 1:1
        par2tokens(al1, al2, src_attrname, dst_attrname, multi_align)

# vim: ts=4 sw=4 sta et sts=4 si cindent tw=80:
