#!/usr/bin/env python
"""
Animals with Attributes Dataset, http://attributes.kyb.tuebingen.mpg.de
Train one binary attribute classifier using all possible features.
Needs "shogun toolbox with python modular interface" for SVM training
(C) 2009-2012 Christoph Lampert <chl@ist.ac.at>
"""

import sys
from numpy import loadtxt, savetxt, fromfile, float32, float64
from numpy import array, asarray, dot, nonzero, sqrt, mean, sign, ones, unique
from platt import SigmoidTrain, SigmoidPredict
from roc import roc

def myfind(bool_array):
    """Return the indices at which ``bool_array`` is true.

    Mimics matlab or pylab's 'find'. For multi-dimensional input only the
    first element of ``numpy.nonzero``'s result (the indices along the
    first axis) is returned.
    """
    hits = nonzero(bool_array)
    first_axis = hits[0]
    return first_axis

# adapt these paths and filenames to match local installation
# (all three patterns live below the same dataset root directory;
#  '%s' is replaced by the dataset name, '%d' by the split index)
_dataset_root = '/media/extern/datasets/Attribute-Based-Classification'
kernel_filepattern = _dataset_root + '/%s/all.kernel'    # flat kernel matrix, read as raw float32
label_filepattern = _dataset_root + '/%s/all.classid'    # text file of integer class ids
mask_filepattern = _dataset_root + '/%s/split%d.mask'    # split mask: >0 train, 0 test, <0 ignored

# load split mask, labels and kernel; return train-vs-train and train-vs-test parts
def create_data(dataset, splitID=0):
    """Load data (kernel and label matrix) based on dataset and splitID.

    Reads the split mask, the multi-class labels and the full precomputed
    kernel matrix via the module-level file patterns. Samples whose mask
    entry is > 0 are training samples, == 0 are test samples, and negative
    entries are ignored.

    Parameters
    ----------
    dataset : str
        Dataset name substituted into the file patterns.
    splitID : int
        Index of the split mask file to use.

    Returns
    -------
    Ktrn : ndarray, shape (ntrn, ntrn)
        Kernel between training samples.
    Ktst : ndarray, shape (ntrn, ntst)
        Kernel between training samples (rows) and test samples (columns).
    Mtrn, Mtst : ndarray of int
        Class labels of the training / test samples, shifted so that the
        smallest label in the label file becomes 0.

    Raises
    ------
    ValueError
        If the kernel file does not contain a square number of entries.
    """
    mask = loadtxt(mask_filepattern % (dataset, splitID), dtype=int)
    train_index = (mask > 0)
    test_index = (mask == 0)  # negatives are ignored

    M = loadtxt(label_filepattern % dataset, dtype=int)  # multi-class labels
    M -= M.min()  # make 0-based index
    Mtrn = M[train_index]
    Mtst = M[test_index]

    K = fromfile(kernel_filepattern % dataset, dtype=float32)  # 3.7GB for AwA
    # the file stores a flat n*n matrix; array shapes must be integers,
    # so convert the (float) square root explicitly and verify it
    n = int(round(sqrt(K.size)))
    if n * n != K.size:
        raise ValueError("kernel file %s does not hold a square matrix (%d entries)"
                         % (kernel_filepattern % dataset, K.size))
    K.shape = (n, n)
    K = K[train_index]  # keep only the rows of training samples
    Ktrn = K[:, train_index]
    Ktst = K[:, test_index]
    return Ktrn, Ktst, Mtrn, Mtst


def train_IAP(dataset, class_ids=[], splitID=0, C=1.):
    """Train per-class binary classifiers: one-vs-rest SVM plus Platt scaling"""
    from shogun.Classifier import LibSVM
    from shogun.Kernel import CustomKernel
    from shogun.Features import BinaryLabels
    
    Ktrn,Ktst,Mtrn,Mtst = create_data(dataset, splitID)
    train_classes = unique(Mtrn)
    ntrn,ntst = Ktst.shape
    if ntrn < 2:
        print >> sys.stderr, "Error: can't train with only 1 class in training set (ignoring)"
        #raise SystemExit
    
    # use 90% of data for actual train, remaining 10% for platt scaling coeffs
    Ip_val = array([ (i%10==9) for i in xrange(ntrn) ], bool)
    Ip_trn = ~Ip_val
    Kp_trn = Ktrn[Ip_trn,:][:,Ip_trn]
    Kp_val = Ktrn[Ip_trn,:][:,Ip_val]
    Kp_tst = Ktst[Ip_trn,:]
    del Ktrn,Ktst
    
    kernel = CustomKernel(Kp_trn)
    svm = LibSVM()
    svm.set_C(C,C)
    svm.set_kernel(kernel)

    if (class_ids == None) or (class_ids == []) or (class_ids == "all"):
        class_ids = train_classes
    
    for class_id in class_ids:
        print
        print "# class ",class_id
        Ltrn = 2.*(Mtrn == class_id)-1.
        Lp_trn = Ltrn[Ip_trn] # 90% for training 
        Lp_val = Ltrn[Ip_val] # remaining 10% for platt scaling 
        
        if Lp_trn.min() == Lp_trn.max():  # only 1 class in training, can't train 1vsRest
            print >> sys.stderr, 'Class not in training set'
            continue  # this can happen if we specify a non-existing class (which can be convenient)
            
        labels_trn = BinaryLabels( asarray(Lp_trn,float64) )
        svm.set_labels(labels_trn)
        try:
            svm.train()
        except (RuntimeWarning,RuntimeError):    # can't train, e.g. all samples have the same labels
            print >> sys.stderr, "Can't train dataset %s split %d C %g class %d" % (dataset, splitID, C, class_id)
            continue
        
        bias = svm.get_bias()
        SVs = svm.get_support_vectors()
        alphas = svm.get_alphas() # contains only the non-zero alphas
        
        # -------------- now apply to validation set and estimate Platt sigmoid parameters --------
        pred = dot(alphas, Kp_val[SVs])+bias # predict SVM scores (should take just a few seconds)
        platt_params = SigmoidTrain(pred, Lp_val)
        prob = SigmoidPredict(pred, platt_params)
        
        savetxt('./%s/results/IAP-split%d_C%g_class%02d-val.txt' % (dataset, splitID,  C, class_id), pred)
        savetxt('./%s/results/IAP-split%d_C%g_class%02d-val.prob' % (dataset, splitID,  C, class_id), prob)
        savetxt('./%s/results/IAP-split%d_C%g_class%02d-val.labels' % (dataset, splitID,  C, class_id), Lp_val)
        savetxt('./%s/results/IAP-split%d_C%g_class%02d-val.platt' % (dataset, splitID,  C, class_id), platt_params)
        print '#val-acc(pred,prob,base) ', class_id, C, mean(sign(pred) == Lp_val ), mean( sign(prob-0.5) == Lp_val ), max( mean(Lp_val>0),mean(Lp_val<0) )
        _,_,auc = roc(None, Lp_val, pred)
        print '#val-roc ', class_id, C, auc
       
        # --------------------- applying to train images is rather pointless  -----------
        apply_to_train=False
        if apply_to_train:
            pred = dot(alphas, Kp_trn[SVs])+bias
            prob = SigmoidPredict(pred, platt_params)
            
            savetxt('./%s/results/IAP-split%d_C%g_class%02d-trn.txt' % (dataset, splitID,  C, class_id), pred)
            savetxt('./%s/results/IAP-split%d_C%g_class%02d-trn.prob' % (dataset, splitID,  C, class_id), prob)
            savetxt('./%s/results/IAP-split%d_C%g_class%02d-trn.labels' % (dataset, splitID,  C, class_id), Lp_trn)
            print '#trn-acc(pred,prob,base) ', class_id, C, mean( sign(pred) == Lp_trn ), mean( sign(prob-0.5) == Lp_trn ), max( mean(Lp_trn>0),mean(Lp_trn<0) )
            _,_,auc = roc(None, Lp_trn, pred)
            print '#trn-roc ', class_id, C, auc
        
        # ----------------------------- apply to test class images ------------------
        pred = dot(alphas, Kp_tst[SVs])+bias
        prob = SigmoidPredict(pred, platt_params)
        
        savetxt('./%s/results/IAP-split%d_C%g_class%02d.txt' % (dataset, splitID,  C, class_id), pred)
        savetxt('./%s/results/IAP-split%d_C%g_class%02d.prob' % (dataset, splitID,  C, class_id), prob)
        print './%s/results/IAP-split%d_C%g_class%02d.prob' % (dataset, splitID,  C, class_id)
        #print '#tst-roc -- does not apply'



if __name__ == '__main__':
    try:
        dataset = sys.argv[1]
    except IndexError:
        print >> sys.stderr, "Usage: %s dataset [classes splitID C ]" % sys.argv[0]
        raise SystemExit

    try:
        class_ids = eval(sys.argv[2])
    except IndexError:
        class_ids = []

    if isinstance(class_ids,int):
        class_ids = [class_ids]
    
    try:
        splitID = int(sys.argv[3])
    except IndexError:
        splitID = 0
    try:
        C = float(sys.argv[4])
    except IndexError:
        C = 1.
    
    print "# data ", dataset
    print "# split ", splitID
    print "# C ", C
    if class_ids == []:
        print "# classes ", "all"
    else: 
        print "# classes ", class_ids
   
    train_IAP(dataset=dataset, class_ids=class_ids, splitID=splitID, C=C)
    print "Done. "

