#!/usr/bin/python
# Copyright (c) 2014-2015  Milos Jakubicek

import sys
from os.path import basename
from manatee import Corpus, new_TokenLevel, full_level, MLTStream

# Literal markup of the "glue" structure. print_aligned() no longer splices
# it into the output text, but the constant is kept in case other code
# imports it.
G_TAG = '<g/>'

def _glue_offsets(g_struc, begpos, endpos):
    """Return offsets (relative to begpos) of <g/> starts in [begpos, endpos).

    A <g/> at offset k means token k is written without a preceding space.
    Returns an empty set when g_struc is falsy (no "g" structure available).

    NOTE(review): assumes RangeStream.find_beg(pos) returns the start of the
    first range beginning at or after pos, and a value >= endpos once no such
    range remains inside the segment.
    """
    offsets = set()
    if not g_struc:
        return offsets
    ranges = g_struc.whole()
    pos = ranges.find_beg(begpos)
    while pos < endpos:
        offsets.add(pos - begpos)
        # pos + 1 guarantees progress to the next glue structure.
        pos = ranges.find_beg(pos + 1)
    return offsets

def _segment_text(struc, attr, g_struc, first, last):
    """Return the text of structures first..last (inclusive) as one line.

    Tokens of attr in [struc.beg(first), struc.end(last)) are separated by a
    single space, except in front of a token where a <g/> structure starts.
    A <g/> on the segment's first token is ignored: the token it would glue
    to lies outside the segment, and there is no space to remove anyway.
    """
    begpos = struc.beg(first)
    endpos = struc.end(last)
    tokens = attr.textat(begpos)
    glued = _glue_offsets(g_struc, begpos, endpos)
    parts = []
    for offset in range(endpos - begpos):
        if offset and offset not in glued:
            parts.append(" ")
        parts.append(tokens.next())
    return "".join(parts)

def print_aligned (al1_struc, al2_struc, al1_attr, al2_attr, al1_beg, al1_end,
                   al2_beg, al2_end, g1_struc, g2_struc):
    """Print one aligned pair as two lines: source text, then target text.

    al1_struc/al2_struc -- alignment structures of the source/target corpus.
    al1_attr/al2_attr   -- attributes whose token text is printed.
    al1_beg, al1_end    -- first and last (inclusive) source structure numbers.
    al2_beg, al2_end    -- first and last (inclusive) target structure numbers.
    g1_struc/g2_struc   -- optional "g" (glue) structures, or None. Where a
                           <g/> starts, no space is printed before the token.
    """
    # Build both lines before printing, so a failure on either side prints
    # nothing (as before).
    al1_text = _segment_text(al1_struc, al1_attr, g1_struc, al1_beg, al1_end)
    al2_text = _segment_text(al2_struc, al2_attr, g2_struc, al2_beg, al2_end)
    # Single-argument print(...) behaves identically under Python 2 and 3.
    print(al1_text)
    print(al2_text)

def process_MLTStream (al1_struc, al2_struc, al1_attr, al2_attr, g1_struc, g2_struc, map_file):
    """Print every aligned pair described by the m:n mapping in map_file.

    The mapping is loaded with full_level(new_TokenLevel(map_file)) and its
    MLTStream is walked change by change:
      * KEEP  -- expanded into one print_aligned() call per structure pair,
                 advancing source and target numbers in lockstep;
      * MORPH -- printed as a single print_aligned() call for the group.
    Any other change type is skipped and produces no output.
    """
    stream = full_level(new_TokenLevel(map_file))
    while not stream.end():
        src = stream.orgpos()
        dst = stream.newpos()
        kind = stream.change_type()
        if kind == MLTStream.KEEP:
            # For KEEP, change_size == change_newsize: pair structures 1:1.
            for offset in xrange(stream.change_size()):
                print_aligned(al1_struc, al2_struc, al1_attr, al2_attr,
                              src + offset, src + offset,
                              dst + offset, dst + offset,
                              g1_struc, g2_struc)
        elif kind == MLTStream.MORPH:
            # NOTE(review): print_aligned's end arguments are inclusive
            # structure numbers (it calls struc.end() on them), so this spans
            # structures src..src+change_size() and dst..dst+change_newsize().
            # Confirm this matches MLTStream's definition of MORPH sizes.
            print_aligned(al1_struc, al2_struc, al1_attr, al2_attr,
                          src, src + stream.change_size(),
                          dst, dst + stream.change_newsize(),
                          g1_struc, g2_struc)
        stream.next()

def _get_optional_struct(corp, name):
    """Return structure `name` of `corp`, or None if it cannot be opened.

    Used for the optional "g" structure. When it is missing, print_aligned
    receives None and joins every token with a plain space.
    """
    try:
        return corp.get_struct(name)
    except Exception:
        # Deliberately best-effort, but unlike the former bare ``except:``
        # this lets KeyboardInterrupt and SystemExit propagate.
        return None

if __name__ == "__main__":
    if len(sys.argv) < 3:
        print("""Usage: %s <SRC_CORPUS> <DST_CORPUS>
Print aligned structures in corpora using mapping from <SRC_CORPUS> to
<DST_CORPUS>.""" % sys.argv[0])
        sys.exit(1)
    al1 = Corpus(sys.argv[1])
    al2 = Corpus(sys.argv[2])
    # Structure named by ALIGNSTRUCT and attribute named by DEFAULTATTR,
    # both read from each corpus' own configuration.
    al1_struc = al1.get_struct(al1.get_conf("ALIGNSTRUCT"))
    al2_struc = al2.get_struct(al2.get_conf("ALIGNSTRUCT"))
    al1_attr = al1.get_attr(al1.get_conf("DEFAULTATTR"))
    al2_attr = al2.get_attr(al2.get_conf("DEFAULTATTR"))
    g1_struc = _get_optional_struct(al1, "g")
    g2_struc = _get_optional_struct(al2, "g")
    if al1.get_conf('ALIGNDEF'): # m:n
        # Mapping file "align.<basename of DST_CORPUS>" in the source PATH.
        map_file = "%s/align.%s" % (al1.get_conf("PATH"), basename(sys.argv[2]))
        process_MLTStream(al1_struc, al2_struc, al1_attr, al2_attr,
                          g1_struc, g2_struc, map_file)
    else: # 1:1
        # The i-th structure of one corpus pairs with the i-th of the other.
        # Surplus structures in the longer corpus are not printed.
        for i in xrange(min(al1_struc.size(), al2_struc.size())):
            print_aligned(al1_struc, al2_struc, al1_attr, al2_attr,
                          i, i, i, i, g1_struc, g2_struc)

# vim: ts=4 sw=4 sta et sts=4 si cindent tw=80:
