#!/usr/bin/python
# Copyright (c) 2015  Milos Jakubicek

import sys, array
from os.path import basename
from os.path import join as pathjoin
from manatee import Corpus, new_TokenLevel, full_level, MLTStream
from struct import pack

def process_MLTStream (al1_struc, al2_struc, map_file):
    """Yield aligned position ranges described by an m:n mapping file.

    Opens map_file as a token level, walks the resulting MLTStream and
    yields tuples (al1_begpos, al1_endpos, al2_begpos, al2_endpos):

      KEEP  -- one tuple per structure; change_size() consecutive structures
               of both sides are paired one to one, starting at orgpos() on
               the al1 side and newpos() on the al2 side.
      MORPH -- a single tuple running from the beginning of the first
               structure to the end of the structure at offset
               change_size() (al1 side) resp. change_newsize() (al2 side).

    Changes of any other type yield nothing.

    al1_struc, al2_struc -- alignment structures providing beg(n) and end(n)
    map_file             -- path handed to new_TokenLevel
    """
    mlts = full_level (new_TokenLevel (map_file))
    while not mlts.end():
        al1_num = mlts.orgpos()
        al2_num = mlts.newpos()
        change = mlts.change_type()
        if change == MLTStream.KEEP:
            for _ in xrange(mlts.change_size()): # change_size = change_newsize
                al1_begpos = al1_struc.beg (al1_num)
                al1_endpos = al1_struc.end (al1_num)
                al2_begpos = al2_struc.beg (al2_num)
                al2_endpos = al2_struc.end (al2_num)
                yield (al1_begpos, al1_endpos, al2_begpos, al2_endpos)
                al1_num += 1
                al2_num += 1
        elif change == MLTStream.MORPH:
            # NOTE(review): the end structure index is start + size rather
            # than start + size - 1 -- confirm this matches the meaning of
            # change_size()/change_newsize() in MLTStream.
            al1_begpos = al1_struc.beg (al1_num)
            al1_endpos = al1_struc.end (al1_num + mlts.change_size())
            al2_begpos = al2_struc.beg (al2_num)
            al2_endpos = al2_struc.end (al2_num + mlts.change_newsize())
            yield (al1_begpos, al1_endpos, al2_begpos, al2_endpos)
        try:
            mlts.next()
        except Exception:
            # An error while advancing ends the walk.  Only ordinary errors
            # are caught, so KeyboardInterrupt and SystemExit still
            # propagate instead of being mistaken for the end of the stream.
            break

def process_OneToOne (al1_struc, al2_struc):
    """Yield aligned position ranges for a 1:1 alignment.

    The n-th structure of al1_struc is paired with the n-th structure of
    al2_struc; pairing stops at the size of the smaller structure.  Each
    yielded tuple is (al1_begpos, al1_endpos, al2_begpos, al2_endpos).
    """
    common_size = min(al1_struc.size(), al2_struc.size())
    for num in xrange(common_size):
        yield (al1_struc.beg (num), al1_struc.end (num),
               al2_struc.beg (num), al2_struc.end (num))

def par2tokens (al1, al2, al1_attrname, al2_attrname, map_file=None):
    """Write aligned token ID pairs to stdout and store per-ID frequencies.

    For every pair of aligned position ranges, each combination of a token
    ID from the al1 range with a token ID from the al2 range is written to
    sys.stdout as three values packed with struct format "I":
    (al1 ID, al2 ID, 1).

    The number of occurrences of every ID inside the aligned ranges is
    counted for both sides and dumped as a raw array.array of type "L" to
        <al1 PATH><al1_attrname>.align.<al2 conffile>.frq
        <al2 PATH><al2_attrname>.align.<al1 conffile>.frq

    al1, al2     -- corpus objects (get_struct, get_conf, get_attr and
                    get_conffile are used)
    al1_attrname -- name of the al1 attribute whose IDs are extracted
    al2_attrname -- name of the al2 attribute whose IDs are extracted
    map_file     -- path of an m:n mapping, processed by process_MLTStream;
                    when empty or None the corpora are taken as aligned 1:1
                    and process_OneToOne is used
    """
    al1_struc = al1.get_struct (al1.get_conf ("ALIGNSTRUCT"))
    al2_struc = al2.get_struct (al2.get_conf ("ALIGNSTRUCT"))
    al1_attr = al1.get_attr (al1_attrname)
    al2_attr = al2.get_attr (al2_attrname)
    freqs1 = [0] * al1_attr.id_range()
    freqs2 = [0] * al2_attr.id_range()
    # Pack every possible ID once so that the output loop only indexes.
    bin_ids = [pack("I", i) for i in xrange(max(len(freqs1), len(freqs2)))]
    one = pack("I", 1)
    if map_file:
        ranges = process_MLTStream(al1_struc, al2_struc, map_file)
    else:
        ranges = process_OneToOne(al1_struc, al2_struc)
    for al1_begpos, al1_endpos, al2_begpos, al2_endpos in ranges:
        al1_it = al1_attr.posat (al1_begpos)
        al2_it = al2_attr.posat (al2_begpos)
        al1_tokens = [al1_it.next() for _ in xrange (al1_endpos - al1_begpos)]
        al2_tokens = [al2_it.next() for _ in xrange (al2_endpos - al2_begpos)]
        for al1_token in al1_tokens:
            freqs1 [al1_token] += 1
            for al2_token in al2_tokens:
                sys.stdout.write (bin_ids[al1_token])
                sys.stdout.write (bin_ids[al2_token])
                sys.stdout.write (one)
        for al2_token in al2_tokens:
            freqs2 [al2_token] += 1
    freqs1 = array.array ("L", freqs1)
    freqs2 = array.array ("L", freqs2)
    # PATH is concatenated as is, so it is presumably expected to end with
    # a path separator -- TODO confirm.
    path1 = al1.get_conf("PATH") + al1_attrname + ".align." \
            + al2.get_conffile() + ".frq"
    path2 = al2.get_conf("PATH") + al2_attrname + ".align." \
            + al1.get_conffile() + ".frq"
    # The dumps are raw machine values: open in binary mode so no newline
    # translation can alter them, and close the files deterministically so
    # the data is flushed before returning.
    with open(path1, "wb") as frq_file:
        freqs1.tofile(frq_file)
    with open(path2, "wb") as frq_file:
        freqs2.tofile(frq_file)

if __name__ == "__main__":
    # Command-line entry point: check the arguments, open both corpora and
    # run par2tokens with the m:n mapping file when the source corpus has
    # ALIGNDEF set, otherwise as a plain 1:1 alignment.
    args = sys.argv
    if len(args) < 5:
        usage = """Usage: %s <SRC_CORPUS> <DST_CORPUS> <SRC_ATTR> <DST_ATTR>
Extracts aligned token IDs on <SRC_ATTR> and <DST_ATTR> attribute from two aligned
corpora using mapping from <SRC_CORPUS> to <DST_CORPUS>.
Prints token IDs to stdout and stores frequencies of IDs in alignments to
<SRC_CORPUS_PATH>/<SRC_ATTR>.align.<DST_CORPUS_NAME>.frq and
<DST_CORPUS_PATH>/<DST_ATTR>.align.<SRC_CORPUS_NAME>.frq"""
        sys.stdout.write (usage % args[0] + "\n")
        sys.exit(1)
    src_corpus = Corpus (args[1])
    dst_corpus = Corpus (args[2])
    mapping = None # 1:1 unless the source corpus defines an alignment
    if src_corpus.get_conf('ALIGNDEF'): # m:n
        mapping = pathjoin(src_corpus.get_conf("PATH"),
                           'align.' + basename(args[2]))
    par2tokens(src_corpus, dst_corpus, args[3], args[4], mapping)

# vim: ts=4 sw=4 sta et sts=4 si cindent tw=80:
