#!/usr/bin/env python
"""
Attribute-Based Classification 
Performs multiclass prediction from binary attributes and evaluates it.
(C) 2012 Christoph Lampert <chl@ist.ac.at>
"""

import sys
from numpy import *

def myfind(bool_array):
    """Return the first-axis indices of the nonzero (True) entries of `bool_array`."""
    # nonzero() yields one index array per dimension; only the first is wanted
    per_axis_indices = nonzero(bool_array)
    return per_axis_indices[0]

def nameonly(x):
    """Return the second tab-separated field of the string `x`.

    Raises IndexError when `x` contains no tab character.
    """
    fields = x.split('\t')
    return fields[1]

def loadstr(filename, converter=str):
    """Read `filename` and return one converted entry per line.

    Each line is stripped of surrounding whitespace and passed through
    `converter` (default: str). The file is closed before returning.
    """
    with open(filename) as handle:
        return [converter(line.strip()) for line in handle]

def loaddict(filename, converter=str):
    """Read a 'key value' text file into a dictionary.

    Every non-blank line is split at its first run of whitespace: the part
    before it becomes the key, the stripped remainder is passed through
    `converter` (default: str) and stored as the value. Blank lines are
    skipped. Later lines overwrite earlier ones with the same key.

    Raises IndexError for a line that consists of a key only.
    """
    D = {}
    with open(filename) as handle:
        for line in handle:
            # split(None, 1): split once on whitespace (the original passed
            # the integer 1 as the separator, which raises TypeError)
            fields = line.split(None, 1)
            if not fields:
                continue  # blank line
            D[fields[0]] = converter(fields[1].strip())

    return D

# adapt these paths and filenames to match local installation

# common root of all input/result files; the first '%s' of every pattern
# below is filled with the dataset name
_ABC_ROOT = '/media/extern/datasets/Attribute-Based-Classification'

label_filepattern = _ABC_ROOT + '/%s/all.classid'
attributematrix_filepattern = _ABC_ROOT + '/%s/predicate-matrix-binary.txt'
attributeperimage_filepattern = _ABC_ROOT + '/%s/all-perimage.attributes'
attributeperclass_filepattern = _ABC_ROOT + '/%s/all-perclass.attributes'
mask_filepattern = _ABC_ROOT + '/%s/split%d.mask'  # (dataset, splitID)

# filled in two steps: first with (dataset, splitID, perimage flag, C), which
# turns '%%02d' into '%02d'; the result is then filled with the attribute index
DAP_predictions_metapattern = _ABC_ROOT + '/%s/results/DAP-split%d_F%d_C%g_attr%%02d.prob'

# when False, the plot_* routines return before writing any PDF file
create_PDF = True

def evaluate(dataset, splitID=0, C=1., perimage=1, special_SUN_flag=False):
    """Predict test classes from per-attribute probabilities and score the result.

    Loads the split mask, the class labels, the binary class-attribute matrix
    and one prediction file per attribute, scores every example against each
    test class and assigns the class with the highest score.

    Parameters:
        dataset          -- dataset name, inserted into the file patterns
        splitID          -- index of the split mask to load
        C, perimage      -- only used to select the prediction files
        special_SUN_flag -- if true, `dataset` must be 'sun' and accuracies on
                            the two SUN hierarchy levels are computed as well

    Returns, when special_SUN_flag is false, the 5-tuple
        (confusion, prob, test_labels, P, mean_class_accuracy)
    where `confusion` is the row-normalised confusion matrix over the test
    classes (rows: ground truth, columns: prediction), `prob` holds the score
    of every test class for every example, `test_labels` are the labels
    selected by the mask, `P` are the loaded attribute predictions and
    `mean_class_accuracy` is the mean of the diagonal of `confusion`.

    When special_SUN_flag is true the 7-tuple
        (confusion, prob, test_labels, P, acc_class, acc_level2, acc_level1)
    is returned instead; these three accuracies are averaged over examples,
    not over classes.
    """
    mask = loadtxt(mask_filepattern % (dataset, splitID), dtype=int)
    Itrn = (mask > 0)
    Itst = (mask == 0)

    L = loadtxt(label_filepattern % dataset, dtype=int)
    L = L - amin(L)  # shift to index 0
    train_classes = unique(L[Itrn])
    test_classes = unique(L[Itst])
    # position of each test class label within test_classes
    class_index = dict((l, i) for i, l in enumerate(test_classes))
    mtst, mtrn = len(test_classes), len(train_classes)

    A = loadtxt(attributematrix_filepattern % dataset, dtype=float)
    m, k = A.shape  # num_classes, num_attributes
    assert(m == mtrn+mtst)  # here: check that classes were disjoint

    pattern = DAP_predictions_metapattern % (dataset, splitID, perimage, C)

    # one file per attribute; after transposing, each row of P holds the
    # predictions of all k attributes for one entry of the files
    # (presumably one entry per test example -- confirm against the files)
    P = [loadtxt(pattern % i, float) for i in range(k)]
    P = array(P).T

    prior = mean(A[train_classes], axis=0)  # average attribute value
    prior[prior == 0.] = 0.5
    prior[prior == 1.] = 0.5    # suppress degenerated priors
    A = A[test_classes]  # (mtst,k)

    # the prior term does not depend on the example, so compute it once
    prior_term = prod(A*prior + (1.-A)*(1-prior), axis=1)
    prob = array([prod(A*pred + (1.-A)*(1.-pred), axis=1)/prior_term
                  for pred in P])

    MCpred = argmax(prob, axis=1)  # predict class by MAP
    numexamples = len(MCpred)

    confusion = zeros([mtst, mtst])
    for pl, nl in zip(MCpred, L[Itst]):
        try:
            gt = class_index[nl]
        except KeyError:
            # label outside the test classes: skip it, as before, but no
            # longer hide unrelated errors behind a bare except
            continue
        confusion[gt, pl] += 1.

    for row in confusion:
        row /= row.sum()  # in-place: normalise each ground-truth row

    if not special_SUN_flag:
        return confusion, prob, L[Itst], P, mean(diag(confusion))

    assert(dataset == 'sun')
    hierarchy_level1 = '/media/extern/datasets/Attribute-Based-Classification/sun/classes-level1.txt'
    hierarchy_level2 = '/media/extern/datasets/Attribute-Based-Classification/sun/classes-level2.txt'

    # compute per-example accuracy on 3 levels: level1, level2, class
    H1 = loadtxt(hierarchy_level1, dtype=int)
    H2 = loadtxt(hierarchy_level2, dtype=int)
    mean_acc1 = 0.
    mean_acc2 = 0.
    mean_accC = 0.
    for pl, nl in zip(MCpred, L[Itst]):
        pC = test_classes[pl]  # predicted class label
        mean_accC += (nl == pC)
        # hierarchy rows of prediction and ground truth must overlap in any
        # (of potentially more than one) entry
        mean_acc1 += (H1[pC]*H1[nl]).any()
        mean_acc2 += (H2[pC]*H2[nl]).any()

    return confusion, prob, L[Itst], P, mean_accC/numexamples, mean_acc2/numexamples, mean_acc1/numexamples


# routines to visualize the results

# per-dataset directory used by the plotting routines; '%s' is the dataset name
_ABC_DATASET_DIR = '/media/extern/datasets/Attribute-Based-Classification/%s/'

classname_filepattern = _ABC_DATASET_DIR + 'classes.txt'
attributename_filepattern = _ABC_DATASET_DIR + 'attributes.txt'
# output PDFs, filled with (dataset, dataset, splitID, perimage flag)
confusion_filepattern = _ABC_DATASET_DIR + 'confusion-%s-split%d-F%d.pdf'
classROC_filepattern = _ABC_DATASET_DIR + 'classROC-%s-split%d-F%d.pdf'
attributeROC_filepattern = _ABC_DATASET_DIR + 'attributeROC-%s-split%d-F%d.pdf'

def plot_confusion(confusion, dataset, splitID, perimage):
    """Draw `confusion` as a grey-scale image and save it as a PDF.

    Both axes are labelled with the names of the test classes of the given
    split ('+' in a name is shown as a space). The output path is built from
    confusion_filepattern. Does nothing when create_PDF is false.
    """
    if not create_PDF:
        return

    all_names = loadstr(classname_filepattern % dataset, nameonly)
    labels = loadtxt(label_filepattern % dataset, dtype=int)
    labels = labels - amin(labels)  # shift to index 0

    split_mask = loadtxt(mask_filepattern % (dataset, splitID), dtype=int)
    held_out = unique(labels[split_mask == 0])
    num_test = len(held_out)
    tick_labels = [all_names[l].replace('+', ' ') for l in held_out]
    tick_positions = arange(0, num_test)

    from pylab import figure,imshow,clim,xticks,yticks,axis,setp,gray,colorbar,savefig,gca
    fig = figure(figsize=(10, 9))
    imshow(confusion, interpolation='nearest', origin='upper')
    clim(0, 1)
    xticks(tick_positions, tick_labels, rotation='vertical', fontsize=24)
    yticks(tick_positions, tick_labels, fontsize=24)
    # y axis runs top-down so that row 0 is drawn at the top
    axis([-.5, num_test-0.5, num_test-0.5, -.5])
    current_axes = gca()
    setp(current_axes.xaxis.get_major_ticks(), pad=18)
    setp(current_axes.yaxis.get_major_ticks(), pad=12)
    fig.subplots_adjust(left=0.30, top=0.98, right=0.98, bottom=0.29)
    gray()
    colorbar(shrink=0.79)
    savefig(confusion_filepattern % (dataset, dataset, splitID, perimage))

def plot_classROC(P, dataset, splitID, perimage):
    """Print the AUC of every test class and plot the per-class ROC curves.

    Column i of `P` is used as the score for the i-th test class of the given
    split, evaluated against the labels selected by the split mask. The AUC
    values are printed in percent units for the mean/std line. Unless
    create_PDF is false, the curves are saved to the path built from
    classROC_filepattern, ordered by decreasing AUC.
    """
    classnames = loadstr(classname_filepattern % dataset, nameonly)
    L = loadtxt(label_filepattern % dataset, dtype=int)
    L = L - amin(L)  # shift to index 0

    mask = loadtxt(mask_filepattern % (dataset, splitID), dtype=int)
    Itst = (mask == 0)
    test_labels = L[Itst]
    test_classes = unique(test_labels)
    test_classnames = [classnames[l] for l in test_classes]

    from roc import roc
    AUC = []
    CURVE = []
    # use the class id directly instead of looking the name up again, which
    # would pick the wrong class if two classes shared a name
    for i, (class_id, c) in enumerate(zip(test_classes, test_classnames)):
        # roc() comes from the project's roc module; it is unpacked here as
        # (true positives, false positives, auc)
        tp, fp, auc = roc(None, test_labels == class_id, P[:, i])
        print("AUC: %s %5.3f" % (c, auc))
        AUC.append(100*auc)
        CURVE.append(array([fp, tp]))

    print("mean class AUC: %5.1f std %5.1f" % (mean(AUC), std(AUC)))

    if not create_PDF:
        return

    from pylab import figure,xticks,yticks,savefig,plot,legend,xlabel,ylabel
    order = argsort(AUC)[::-1]  # best class first in the legend
    fig = figure(figsize=(9, 5))
    for i in order:
        c = test_classnames[i]
        plot(CURVE[i][0], CURVE[i][1], label='%s (AUC: %4.1f)' % (c, AUC[i]), linewidth=3)

    plot([0, 1], [0, 1], ':', label='random', linewidth=2)
    legend(loc='lower right')
    xticks([0.0,0.2,0.4,0.6,0.8,1.0], [r'$0$', r'$0.2$',r'$0.4$',r'$0.6$',r'$0.8$',r'$1.0$'], fontsize=18)
    yticks([0.0,0.2,0.4,0.6,0.8,1.0], [r'$0$', r'$0.2$',r'$0.4$',r'$0.6$',r'$0.8$',r'$1.0$'], fontsize=18)
    # the x coordinate of every curve is fp, i.e. the false positive rate
    xlabel('false positive rate', fontsize=18)
    ylabel('true positive rate', fontsize=18)
    fig.subplots_adjust(left=0.1, top=0.95, right=0.97, bottom=0.15)
    savefig(classROC_filepattern % (dataset, dataset, splitID, perimage))

def plot_attributeROC(P, dataset, splitID, perimage):
    """Print the AUC of every attribute predictor and plot them as a bar chart.

    Column i of `P` is compared against column i of the ground-truth attribute
    file (per-image if perimage == 1, per-class otherwise), restricted to the
    entries selected by the split mask. Attributes whose ground truth is
    constant get AUC 0 and are left out of the printed mean/std. Unless
    create_PDF is false, the chart is saved to the path built from
    attributeROC_filepattern.
    """
    attributenames = loadstr(attributename_filepattern % dataset, nameonly)

    mask = loadtxt(mask_filepattern % (dataset, splitID), dtype=int)
    Itst = (mask == 0)
    # per-image ground truth only for perimage == 1; per-class otherwise
    # (also used for IAP)
    if perimage == 1:
        truth_pattern = attributeperimage_filepattern
    else:
        truth_pattern = attributeperclass_filepattern
    A = loadtxt(truth_pattern % dataset, dtype=float)[Itst]
    _, k = A.shape  # num_test_examples, num_attributes

    from roc import roc
    AUC = []
    for i, c in enumerate(attributenames):
        truth = A[:, i]
        if truth.min() != truth.max():
            _, _, auc = roc(None, truth, P[:, i])
        else:
            auc = 0.  # 'doesn't work' flag: ground truth is constant
        print("AUC: %d %s %5.3f" % (i, c, auc))
        AUC.append(auc)

    fixedAUC = [value for value in AUC if value > 0.01]  # remove empty entries
    if len(fixedAUC) > 0:
        summary = (mean(fixedAUC), std(fixedAUC))
    else:
        summary = (0, 0)
    print("mean attr AUC: %5.3f std %5.3f" % summary)

    if not create_PDF:
        return

    from pylab import figure,xticks,yticks,setp,gray,colorbar,savefig,gca,clf,plot,legend,xlabel,ylabel,xlim,bar
    positions = arange(k)
    fig = figure(figsize=(19, 5))
    plot([0., k], [.5, .5], 'r', linewidth=3, zorder=1)  # chance level
    bar(0.2+positions, AUC, 0.8, color='b', zorder=2)
    xlim(0, k)
    xticks(0.5+positions, attributenames, fontsize=12, rotation='vertical')
    yticks([0.0,0.25,0.5,0.75,1.0], [r'$0$', r'$0.25$',r'$0.5$',r'$0.75$',r'$1$'], fontsize=18)
    fig.subplots_adjust(left=0.05, top=0.98, right=0.99, bottom=0.24)
    ylabel('area under ROC curve', fontsize=18)
    savefig(attributeROC_filepattern % (dataset, dataset, splitID, perimage))

def main():
    """Command-line entry point: evaluate one dataset split and plot the results.

    Arguments (from sys.argv): dataset [splitID C perimage-flag]. Missing
    optional arguments default to 0, 1.0 and 0. Without a dataset the usage
    line is printed and SystemExit is raised.
    """
    import sys
    argv = sys.argv

    def optional_arg(position, convert, fallback):
        # converted argv[position], or `fallback` if the argument is missing
        try:
            return convert(argv[position])
        except IndexError:
            return fallback

    if len(argv) < 2:
        print("Usage: %s dataset [splitID C perimage-flag]" % argv[0])
        raise SystemExit

    dataset = argv[1]
    splitID = optional_arg(2, int, 0)
    C = optional_arg(3, float, 1.)
    perimage_flag = optional_arg(4, int, 0)

    is_sun = (dataset == 'sun')
    if is_sun:
        # SUN additionally reports accuracies on its two hierarchy levels
        confusion, prob, L, P, mean_acc, acc2, acc1 = evaluate(dataset, splitID, C, perimage_flag, special_SUN_flag=True)
    else:
        confusion, prob, L, P, mean_acc = evaluate(dataset, splitID, C, perimage_flag)

    print("prob: mean class accuracy %g" % (100*mean_acc))
    if is_sun:
        print("prob: mean level 2 accuracy %g" % (100*acc2))
        print("prob: mean level 1 accuracy %g" % (100*acc1))

    plot_confusion(confusion, dataset, splitID, perimage_flag)
    plot_classROC(prob, dataset, splitID, perimage_flag)
    plot_attributeROC(P, dataset, splitID, perimage_flag)

# run the evaluation only when this file is executed as a script
if __name__ == '__main__':
    main()
