#!/usr/bin/env python
"""
Animals with Attributes Dataset, http://attributes.kyb.tuebingen.mpg.de
Train binary attribute classifiers (one per requested attribute, all attributes by default) using all available features.
Needs "shogun toolbox with python modular interface" for SVM training
(C) 2009-2012 Christoph Lampert <chl@ist.ac.at>
"""

import sys
from numpy import loadtxt, savetxt, fromfile, float32, float64
from numpy import array, asarray, dot, nonzero, sqrt, mean, sign, ones
from platt import SigmoidTrain, SigmoidPredict
from roc import roc

kernel_filepattern = '%s/all.kernel'
attribute_perimage_filepattern = '%s/all-perimage.attributes'
attribute_perclass_filepattern = '%s/all-perclass.attributes'
mask_filepattern = '%s/split%d.mask'

# load train/test split mask, attribute labels and kernel; split each into train and test parts
def create_data(dataset, perimage=1, splitID=0):
    """Load data (kernel and label matrix) based on dataset and splitID.

    Parameters
    ----------
    dataset : str
        Directory holding 'all.kernel', the attribute files and 'split<ID>.mask'.
    perimage : int/bool
        Truthy: read 'all-perimage.attributes'; falsy: 'all-perclass.attributes'.
    splitID : int
        Selects 'split<splitID>.mask'. Mask entries > 0 mark training images,
        entries == 0 mark test images, negative entries are in neither part.

    Returns
    -------
    Ktrn : kernel rows and columns restricted to training images, (ntrn, ntrn)
    Ktst : kernel rows = training images, columns = test images, (ntrn, ntst)
    Atrn, Atst : image-attribute matrix rows for training / test images

    Raises
    ------
    ValueError
        If the kernel file does not contain a square number of float32 values.
    """
    M = loadtxt(mask_filepattern % (dataset, splitID), dtype=int)
    train_index = (M > 0)
    test_index = (M == 0)  # negative mask entries are ignored (neither train nor test)

    if perimage:
        A = loadtxt(attribute_perimage_filepattern % dataset, dtype=int)  # image-attribute matrix
    else:
        A = loadtxt(attribute_perclass_filepattern % dataset, dtype=int)  # image-attribute matrix

    Atrn = A[train_index]
    Atst = A[test_index]

    # raw float32 dump of a square kernel matrix (3.7GB for AwA)
    K = fromfile(kernel_filepattern % dataset, dtype=float32)
    # ndarray dimensions must be integers; sqrt() returns a float
    n = int(round(sqrt(len(K))))
    if n * n != len(K):
        raise ValueError("kernel file %s holds %d values, which is not a square matrix"
                         % (kernel_filepattern % dataset, len(K)))
    K.shape = (n, n)
    K = K[train_index]          # keep only rows of training images
    Ktrn = K[:, train_index]
    Ktst = K[:, test_index]
    return Ktrn, Ktst, Atrn, Atst

# train subset of attributes (same kernel, but labels change)
def train_attribute(dataset, attribute_ids=[], splitID=0, perimage=1, C=1.):
    """Train per-attribute binary classifiers: SVM plus Platt scaling"""
    from shogun.Classifier import LibSVM
    from shogun.Kernel import CustomKernel
    from shogun.Features import BinaryLabels
    
    Ktrn,Ktst,Atrn,Atst = create_data(dataset, perimage, splitID)
    ntrn,ntst = Ktst.shape
    if attribute_ids == []:
        attribute_ids = range(Atrn.shape[1]) # all attributes
    
    # we do independent per-attribute training 
    Ip_val = array([ (i%10==9) for i in xrange(ntrn) ], bool)
    Ip_trn = ~Ip_val
    Kp_trn = Ktrn[Ip_trn,:][:,Ip_trn]
    Kp_val = Ktrn[Ip_trn,:][:,Ip_val]
    Kp_tst = Ktst[Ip_trn,:]
    del Ktrn,Ktst
    
    kernel = CustomKernel(Kp_trn)
    svm = LibSVM()
    svm.set_C(C,C)
    svm.set_kernel(kernel)
    
    for attribute_id in attribute_ids:
        print "# attribute ",attribute_id
        Ltrn = 2.*Atrn[:,attribute_id]-1.
        Ltst = 2.*Atst[:,attribute_id]-1. # not used anywhere, just output for convenience evaluation
        Lp_trn = Ltrn[Ip_trn] # label of 90% training samples
        Lp_val = Ltrn[Ip_val] # labels of remaining 10% (for Platt scaling)
        
        #print Lp_trn,Lp_val 
        if Lp_trn.min() == Lp_trn.max():  # only 1 class in training, predict constant output
            Lprior = mean(Ltrn)
            pred = sign(Lprior) * ones(ntst)
            prob = 0.1+0.8*0.5*(Lprior+1.) * ones(ntst)
            savetxt('./%s/results/DAP-split%d_F%d_C%g_attr%02d.txt' % (dataset, splitID, perimage, C, attribute_id), pred)
            savetxt('./%s/results/DAP-split%d_F%d_C%g_attr%02d.prob' % (dataset, splitID, perimage, C, attribute_id), prob)
            savetxt('./%s/results/DAP-split%d_F%d_C%g_attr%02d.labels' % (dataset, splitID, perimage, C, attribute_id), Ltst)
            print '#tst-SVM-perf(const)   ', attribute_id, C, mean((pred*Ltst)>0), max( mean(Ltst>0),mean(Ltst<0) )
            print '#tst-platt-perf(const) ', attribute_id, C, mean((sign(prob-0.5)*Ltst)>0), max( mean(Ltst>0),mean(Ltst<0) )
            continue
        
        labels_trn = BinaryLabels( asarray(Lp_trn,float64) )
        svm.set_labels(labels_trn)
        try:
            svm.train() # should be fast with precomputed kernel (1min for AwA or so)
        except (RuntimeWarning,RuntimeError):    # can't train, e.g. all samples have the same labels
            print >> sys.stderr, "Can't train dataset %s split %d C %g attribute %s perclass %d" % (dataset, splitID, C, attribute_id, perimage)
            continue
        
        bias = svm.get_bias() # SVM bias term
        SVs = svm.get_support_vectors() # indices of support vectors 
        alphas = svm.get_alphas() # contains only the non-zero alphas
        
        # -------------- now apply to validation set and estimate Platt sigmoid parameters --------
        pred = dot(alphas, Kp_val[SVs])+bias # predict SVM scores (should take just a few seconds)
        platt_params = SigmoidTrain(pred, Lp_val)
        prob = SigmoidPredict(pred, platt_params)
        
        savetxt('./%s/results/DAP-split%d_F%d_C%g_attr%02d-val.txt' % (dataset, splitID, perimage, C, attribute_id), pred)
        savetxt('./%s/results/DAP-split%d_F%d_C%g_attr%02d-val.prob' % (dataset, splitID, perimage, C, attribute_id), prob)
        savetxt('./%s/results/DAP-split%d_F%d_C%g_attr%02d-val.labels' % (dataset, splitID, perimage, C, attribute_id), Lp_val)
        savetxt('./%s/results/DAP-split%d_F%d_C%g_attr%02d-val.platt' % (dataset, splitID, perimage, C, attribute_id), platt_params)
        print '#val-acc(pred,prob,base) ', attribute_id, C, mean(sign(pred) == Lp_val ), mean( sign(prob-0.5) == Lp_val ), max( mean(Lp_val>0),mean(Lp_val<0) )
        if Lp_val.min() != Lp_val.max():
            _,_,auc = roc(None, Lp_val, pred)
            print '#val-roc ', attribute_id, C, auc
        else:
            print '#val-roc ', attribute_id, C, 0.
        
        # --------------------- applying to train images is rather pointless  -----------
        pred = dot(alphas, Kp_trn[SVs]) + bias
        prob = SigmoidPredict(pred, platt_params)
        
        savetxt('./%s/results/DAP-split%d_F%d_C%g_attr%02d-trn.txt' % (dataset, splitID, perimage, C, attribute_id), pred)
        savetxt('./%s/results/DAP-split%d_F%d_C%g_attr%02d-trn.prob' % (dataset, splitID, perimage, C, attribute_id), prob)
        savetxt('./%s/results/DAP-split%d_F%d_C%g_attr%02d-trn.labels' % (dataset, splitID, perimage, C, attribute_id), Lp_trn)
        print '#trn-acc(pred,prob,base) ', attribute_id, C, mean( sign(pred) == Lp_trn ), mean( sign(prob-0.5) == Lp_trn ), max( mean(Lp_trn>0),mean(Lp_trn<0) )
        if Lp_trn.min() != Lp_trn.max():
            _,_,auc = roc(None, Lp_trn, pred)
            print '#trn-roc ', attribute_id, C, auc
        else: 
            print '#trn-roc ', attribute_id, C, 0.
        
        # ----------------------------- apply to test class images ------------------
        pred = dot(alphas, Kp_tst[SVs]) + bias
        prob = SigmoidPredict(pred, platt_params)
        
        savetxt('./%s/results/DAP-split%d_F%d_C%g_attr%02d.txt' % (dataset, splitID, perimage, C, attribute_id), pred)
        savetxt('./%s/results/DAP-split%d_F%d_C%g_attr%02d.prob' % (dataset, splitID, perimage, C, attribute_id), prob)
        savetxt('./%s/results/DAP-split%d_F%d_C%g_attr%02d.labels' % (dataset, splitID, perimage, C, attribute_id), Ltst)
        print '#tst-acc(pred,prob,base) ', attribute_id, C, mean( sign(pred) == Ltst ), mean( sign(prob-0.5) == Ltst), max( mean(Ltst>0), mean(Ltst<0) )
        #print mean(Ltst),mean(Lp_trn),mean(Lp_val)
        if Ltst.min() != Ltst.max(): 
            _,_,auc = roc(None, Ltst, pred)
            print '#tst-roc ', attribute_id, C, auc
        else:
            print '#tst-roc ', attribute_id, C, 0.


if __name__ == '__main__':
    try:
        dataset = sys.argv[1]
    except IndexError:
        print >> sys.stderr, "Usage: %s dataset [attributeID/IDs splitID C perimage-flag]" % sys.argv[0]
        raise SystemExit

    try:
        attribute_ids = eval(sys.argv[2])
    except IndexError:
        attribute_ids = []

    if isinstance(attribute_ids,int):
        attribute_ids = [attribute_ids]
    
    try:
        splitID = int(sys.argv[3])
    except IndexError:
        splitID = 0
    try:
        C = float(sys.argv[4])
    except IndexError:
        C = 1.
    try:
        perimage_flag = int(sys.argv[5])
    except IndexError:
        perimage_flag = 0
    
    print "# data ", dataset
    print "# split ", splitID
    print "# C ", C
    if attribute_ids == []:
        print "# attrs ", "all"
    else: 
        print "# attrs ", attribute_ids
    print "# perimage ", perimage_flag
    
    train_attribute(dataset=dataset, attribute_ids=attribute_ids, splitID=splitID, perimage=perimage_flag, C=C)
    print "Done. "

