#!/usr/bin/env python3
# Copyright (c) 2014-2016  Milos Jakubicek
from __future__ import print_function
from __future__ import unicode_literals

import sys
from os.path import basename
from manatee import Corpus, new_TokenLevel, full_level, MLTStream, setEncoding

# Python 2 compatibility: use the lazy xrange where that name exists; on
# Python 3 the lookup raises NameError and the builtin range is kept.
try:
    range = xrange
except NameError:
    pass

# Output stream used for already-encoded data: sys.stdout.buffer when the
# attribute exists, otherwise sys.stdout itself.
binary_stdout = getattr(sys.stdout, 'buffer', sys.stdout)

# Marker inserted into token lists at the positions where a "g" structure
# begins.
G_TAG = '<g/>'

def _glue_positions (g_struc, begpos, endpos):
    """Return the beginnings of "g" structures located in [begpos, endpos).

    An empty list is returned when g_struc is false (e.g. None).
    """
    positions = []
    if g_struc:
        g_rs = g_struc.whole()
        pos = g_rs.find_beg (begpos)
        while pos < endpos and not g_rs.end():
            positions.append (pos)
            pos = g_rs.find_beg (pos + 1)
    return positions


def _write_line (tokens):
    """Write the space-joined tokens as one UTF-8 line to binary_stdout."""
    # Only a G_TAG that has a token on both sides is removed (together with
    # the spaces around it), which glues the two neighbouring tokens.
    line = " ".join (tokens).replace (" %s " % G_TAG, "")
    binary_stdout.write (line.encode ('utf-8'))
    # Bytes literal: binary_stdout may be a binary stream (sys.stdout.buffer),
    # which rejects text strings.
    binary_stdout.write (b"\n")


def print_aligned (al1_struc, al2_struc, al1_attr, al2_attr, al1_beg, al1_end,
        al2_beg, al2_end, g1_struc, g2_struc, separate_docs, d):
    """Print one aligned segment pair as two lines on binary_stdout.

    The first line holds the tokens of structures al1_beg..al1_end
    (inclusive structure numbers in al1_struc), the second line the tokens
    of structures al2_beg..al2_end in al2_struc.

    al1_struc, al2_struc -- structures whose beg()/end() give token positions
    al1_attr, al2_attr   -- attributes whose textat() yields the token texts
    g1_struc, g2_struc   -- "g" structures or None; G_TAG is inserted at the
                            position of every "g" beginning in the segment
    separate_docs        -- when true, "<doc>" is written before the first
                            segment that ends beyond the current document
    d                    -- two-item list [document structure, current
                            document number]; d[1] is updated in place and
                            a negative value means "not initialised yet"
    """
    al1_begpos = al1_struc.beg (al1_beg)
    al2_begpos = al2_struc.beg (al2_beg)
    al1_endpos = al1_struc.end (al1_end)
    al2_endpos = al2_struc.end (al2_end)
    al1_it = al1_attr.textat (al1_begpos)
    al2_it = al2_attr.textat (al2_begpos)

    # First call: remember the document the first segment belongs to.
    if d[1] < 0:
        d[1] = d[0].num_at_pos (al1_endpos)

    g1 = _glue_positions (g1_struc, al1_begpos, al1_endpos)
    g2 = _glue_positions (g2_struc, al2_begpos, al2_endpos)

    al1_tokens = [al1_it.next() for _ in range (al1_endpos - al1_begpos)]
    al2_tokens = [al2_it.next() for _ in range (al2_endpos - al2_begpos)]

    # Every tag already inserted shifts the following tokens by one, hence
    # the running offset added to the token index.
    for shift, pos in enumerate (g1):
        al1_tokens.insert (pos - al1_begpos + shift, G_TAG)
    for shift, pos in enumerate (g2):
        al2_tokens.insert (pos - al2_begpos + shift, G_TAG)

    if separate_docs and al1_endpos > d[0].end (d[1]):
        # Bytes literal for the same reason as the newline in _write_line.
        binary_stdout.write (b'<doc>')
        d[1] = d[0].num_at_pos (al1_endpos)

    _write_line (al1_tokens)
    _write_line (al2_tokens)

def process_MLTStream (al1_struc, al2_struc, al1_attr, al2_attr, g1_struc,
        g2_struc, map_file, separate_docs, d):
    """Print all aligned segments described by the mapping in map_file.

    The mapping is read as a stream of changes.  A KEEP change stands for
    change_size() consecutive 1:1 pairs, each printed separately; a MORPH
    change is printed as one pair covering change_size() source structures
    and change_newsize() target structures.  Other change types produce no
    output.  The remaining arguments are handed to print_aligned unchanged.
    """
    stream = full_level (new_TokenLevel (map_file))
    while not stream.end():
        src_first = stream.orgpos()
        dst_first = stream.newpos()
        if stream.change_type() == MLTStream.KEEP:
            # change_size = change_newsize
            for offset in range (stream.change_size()):
                src_num = src_first + offset
                dst_num = dst_first + offset
                print_aligned (al1_struc, al2_struc, al1_attr, al2_attr,
                               src_num, src_num, dst_num, dst_num,
                               g1_struc, g2_struc, separate_docs, d)
        elif stream.change_type() == MLTStream.MORPH:
            src_last = src_first + stream.change_size() - 1
            dst_last = dst_first + stream.change_newsize() - 1
            print_aligned (al1_struc, al2_struc, al1_attr, al2_attr,
                           src_first, src_last, dst_first, dst_last,
                           g1_struc, g2_struc, separate_docs, d)
        stream.next()

if __name__ == "__main__":
    # Both corpus arguments are mandatory; --sepdoc is optional.
    if len(sys.argv) < 3:
        print("""Usage: %s <SRC_CORPUS> <DST_CORPUS> [--sepdoc]
Print aligned structures in corpora using mapping from <SRC_CORPUS> to
<DST_CORPUS>.
Use --sepdoc to print "<doc>" at document boundaries.""" % sys.argv[0],
            file=sys.stderr)
        sys.exit(1)
    al1 = Corpus (sys.argv[1])
    al2 = Corpus (sys.argv[2])
    setEncoding(al1.get_conf("ENCODING"))
    al1_struc = al1.get_struct (al1.get_conf ("ALIGNSTRUCT"))
    al2_struc = al2.get_struct (al2.get_conf ("ALIGNSTRUCT"))
    al1_attr = al1.get_attr (al1.get_conf ("DEFAULTATTR"))
    al2_attr = al2.get_attr (al2.get_conf ("DEFAULTATTR"))
    # Shared mutable state for print_aligned: [document structure of the
    # source corpus, current document number (-1 = not determined yet)].
    d = [al1.get_struct(al1.get_conf('DOCSTRUCTURE')), -1]
    separate_docs = '--sepdoc' in sys.argv
    # The "g" structure is optional: fall back to None when it cannot be
    # obtained.  Catch Exception rather than using a bare except so that
    # KeyboardInterrupt and SystemExit are not swallowed.
    try:
        g1_struc = al1.get_struct ("g")
    except Exception:
        g1_struc = None
    try:
        g2_struc = al2.get_struct ("g")
    except Exception:
        g2_struc = None
    if al1.get_conf('ALIGNDEF'): # m:n
        # The mapping file lives in the source corpus PATH and is named
        # after the target corpus argument.
        map_file = "%s/align.%s" % (al1.get_conf("PATH"), basename(sys.argv[2]))
        process_MLTStream (al1_struc, al2_struc, al1_attr, al2_attr, g1_struc,
                g2_struc, map_file, separate_docs, d)
    else: # 1:1
        # Pair structures with the same number, up to the shorter corpus.
        for i in range(min(al1_struc.size(), al2_struc.size())):
            print_aligned (al1_struc, al2_struc, al1_attr, al2_attr, i, i, i, i,
                    g1_struc, g2_struc, separate_docs, d)

# vim: ts=4 sw=4 sta et sts=4 si cindent tw=80:
