/*
    Copyright Vladimir Kolmogorov vnk@ist.ac.at 2014

    This file is part of SVM.

    SVM is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    SVM is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with SVM.  If not, see <http://www.gnu.org/licenses/>.
*/

#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include "SVM.h"
#include "SVMutils.h"



//////////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////////

/*
	Constructs the solver state.
	  _d                - dimension of the weight vector w
	  _n                - number of terms exposed by the user callback (stored as n0)
	  _lambda, _mu      - stored as the ratios lambda/mu and mu/lambda, plus mu itself
	  _max_fn           - user callback, invoked as max_fn0(i, w, a, user_arg)
	  _user_arg         - opaque pointer forwarded to _max_fn
	  _zero_lower_bound - stored as-is
	  _group_size       - number of consecutive callback terms merged into one internal term
*/
SVM::SVM(int _d, int _n, double _lambda, double _mu, MaxFn _max_fn, void* _user_arg, bool _zero_lower_bound, int _group_size) :
	d(_d), n0(_n), group_size(_group_size),
	lambda_mu(_lambda/_mu), lambda_mu_inv(_mu/_lambda), mu(_mu),
	max_fn0(_max_fn), zero_lower_bound(_zero_lower_bound), user_arg(_user_arg),
	terms(NULL), current_sum(NULL), buf(1024)
{
	// Number of internal (grouped) terms: ceil(n0 / group_size).
	// Identical to ((n0-1)/group_size)+1 for n0 >= 1, but also yields 0 for n0 == 0
	// (the old form gave 1 when group_size > 1, so max_fn(0) would query callback index 0
	// even though no callback terms exist).
	n = (n0 + group_size - 1) / group_size;
	w = (double*) buf.Alloc(d*sizeof(double));
	// Scratch vector (d coefficients + 1 offset) used by max_fn() to sum grouped
	// callback outputs; not needed without grouping, so keep it NULL rather than indeterminate.
	max_fn_buf = (group_size > 1) ? (double*) buf.Alloc((d+1)*sizeof(double)) : NULL;
}

// Releases every per-term object (only if the term array was ever created)
// followed by the array itself.
SVM::~SVM()
{
	if (!terms) return;
	for (int k=0; k<n; k++) delete terms[k];
	delete [] terms;
}


////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////

/*
	Evaluates internal term i at the current w, writing d+1 values into a.
	Without grouping (group_size == 1) this forwards directly to the user callback.
	Otherwise internal term i covers callback indices
	[i*group_size, min((i+1)*group_size, n0)), and a receives the element-wise
	sum of their outputs; max_fn_buf serves as scratch for all but the first.
*/
void SVM::max_fn(int i, double* a)
{
	if (group_size == 1)
	{
		(*max_fn0)(i, w, a, user_arg);
		return;
	}

	int first = i*group_size;
	int end = first + group_size;
	if (end > n0) end = n0;

	(*max_fn0)(first, w, a, user_arg); // first member initializes a directly
	for (int j=first+1; j<end; j++)
	{
		(*max_fn0)(j, w, max_fn_buf, user_arg);
		for (int k=0; k<=d; k++) a[k] += max_fn_buf[k];
	}
}

////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////

double SVM::Evaluate(double* _w)
{
	int i;

	double* tmp = new double[2*d+1];
	double* w_bak = tmp+d+1;
	memcpy(w_bak, w, d*sizeof(double));
	if (w != _w) memcpy(w, _w, d*sizeof(double));

	double v0 = Norm(w, d), v1 = 0;

	for (i=0; i<n; i++)
	{
		max_fn(i, tmp);
		v1 += DotProduct(w, tmp, d) + tmp[d];
	}

	memcpy(w, w_bak, d*sizeof(double));

	delete [] tmp;

	return mu * (v0*lambda_mu/2 + v1);
}










/////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////
////////////////////// Implementation of 'Term' /////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////





/*
	Creates an empty term for a d-dimensional problem. The vector `current`
	(d+1 doubles) is carved out of the shared Buffer *_buf. Plane bookkeeping
	starts with room for 4 planes; AddPlane() grows it via Allocate() as needed.
	If maintain_products is set, a cache of pairwise plane products is kept too.
*/
SVM::Term::Term(int _d, SVM* _svm, Buffer* _buf, bool maintain_products)
	: d(_d), svm(_svm), buf(_buf), num(0), num_max(0)
{
	current = (double*) buf->Alloc((d+1)*sizeof(double));
	a = NULL; last_accessed = NULL; products = NULL; my_buf = NULL;
	// Small initial capacity instead of svm->options.cp_max; grown on demand.
	const int initial_capacity = 4;
	Allocate(initial_capacity, maintain_products);
}

/*
	(Re)allocates per-term bookkeeping for up to num_max_new planes, preserving contents:
	  a[]             - plane pointers; slots beyond the old capacity are set to NULL
	  products[][]    - only if maintain_products; the leading num x num block is copied
	  last_accessed[] - per-plane timestamps; the old entries are copied
	All arrays share a single allocation (my_buf); the previous block is freed.
	Assumes num_max_new >= num_max and num_max_new >= num (true for the callers in this file).

	Layout of my_buf: [a: double*][products: double*][matrix: double][last_accessed: float].
	Keeping every 8-byte component ahead of the 4-byte floats makes each sub-array
	correctly aligned for any num_max. Putting the floats in the middle misaligned
	products/matrix whenever num_max was odd (e.g. after growing from 4 to 9 planes).
*/
void SVM::Term::Allocate(int num_max_new, bool maintain_products)
{
	int num_max_old = num_max;
	double** a_old = a;
	float* last_accessed_old = last_accessed;
	double** products_old = products;
	char* my_buf_old = my_buf;
	num_max = num_max_new;

	size_t ptr_bytes    = (size_t)num_max*sizeof(double*);
	size_t float_bytes  = (size_t)num_max*sizeof(float);
	size_t matrix_bytes = (maintain_products) ? (size_t)num_max*num_max*sizeof(double) : 0;
	size_t my_buf_size  = ptr_bytes + float_bytes;
	if (maintain_products) my_buf_size += ptr_bytes + matrix_bytes;
	my_buf = new char[my_buf_size];
	char* p = my_buf;

	int i;
	a = (double**)p; p += ptr_bytes;
	for (i=0; i<num_max_old; i++) a[i] = a_old[i];
	for ( ; i<num_max; i++) a[i] = NULL;

	if (maintain_products)
	{
		products = (double**)p; p += ptr_bytes;
		double* matrix = (double*)p; p += matrix_bytes;
		for (i=0; i<num_max; i++) products[i] = matrix + (size_t)i*num_max; // row i of the num_max x num_max matrix
		int t1, t2;
		for (t1=0; t1<num; t1++)
		for (t2=0; t2<num; t2++)
		{
			products[t1][t2] = products_old[t1][t2];
		}
	}

	last_accessed = (float*)p;
	// Guard: on the first call last_accessed_old is NULL, and passing NULL to memcpy
	// is undefined even for a zero length. Element size is sizeof(float), not sizeof(int).
	if (num_max_old > 0) memcpy(last_accessed, last_accessed_old, num_max_old*sizeof(float));

	if (my_buf_old) delete [] my_buf_old;
}


// Frees the single block holding a / last_accessed / products. The plane
// vectors a[t] themselves were taken from the shared Buffer and are not freed here.
SVM::Term::~Term()
{
	delete [] my_buf; // delete[] on NULL is a no-op
}

// Returns true if some stored plane is byte-for-byte identical to x (all d+1
// entries). On a hit, that plane's last_accessed is refreshed to svm->timestamp.
bool SVM::Term::isDuplicate(double* x)
{
	const size_t bytes = (d+1)*sizeof(double);
	for (int t=0; t<num; t++)
	{
		if (memcmp(x, a[t], bytes) != 0) continue;
		last_accessed[t] = svm->timestamp;
		return true;
	}
	return false;
}

/*
	Stores plane x (d+1 doubles) and returns the slot index it occupies.
	  - With fewer than cp_max planes, a new slot is appended. Capacity grows to
	    2*num_max+1 (capped at cp_max) when full. The slot's vector is taken from
	    the shared Buffer if it has none yet, and svm->total_plane_num is incremented.
	  - Otherwise the slot with the smallest last_accessed (earliest index on ties)
	    is overwritten.
	The slot is stamped with svm->timestamp. If the products cache exists, its
	row and column for the slot are reset to NOT_YET_COMPUTED.
*/
int SVM::Term::AddPlane(double* x, int cp_max)
{
	int slot;
	if (num < cp_max)
	{
		if (num >= num_max)
		{
			int grown = 2*num_max + 1;
			Allocate((grown > cp_max) ? cp_max : grown, products != NULL);
		}
		slot = num++;
		if (a[slot] == NULL) a[slot] = (double*) buf->Alloc((d+1)*sizeof(double));
		svm->total_plane_num++;
	}
	else
	{
		slot = 0;
		for (int k=1; k<num; k++)
		{
			if (last_accessed[k] < last_accessed[slot]) slot = k;
		}
	}

	memcpy(a[slot], x, (d+1)*sizeof(double));
	last_accessed[slot] = svm->timestamp;

	if (products)
	{
		for (int k=0; k<num; k++) products[slot][k] = products[k][slot] = NOT_YET_COMPUTED;
	}
	return slot;
}

/*
	Removes plane t by moving the last plane into slot t and decrementing num
	and svm->total_plane_num. The vector pointers are swapped rather than
	overwritten, so the removed plane's storage stays in the now-unused slot
	and can be reused by a later AddPlane(). last_accessed and, if present,
	the products cache (kept symmetric) are updated to match.
*/
void SVM::Term::DeletePlane(int t)
{
	int last = --num;
	svm->total_plane_num--;
	if (t == last) return;

	double* moved = a[last];
	a[last] = a[t];
	a[t] = moved;
	last_accessed[t] = last_accessed[last];

	if (!products) return;
	for (int k=0; k<last; k++)
	{
		products[t][k] = products[k][t] = products[last][k];
	}
	products[t][t] = products[last][last]; // loop stored products[last][t] on the diagonal; correct it
}

/*
	Returns the index of the stored plane maximizing <a[t], w> + a[t][d].
	On ties the highest such index wins (comparison is v_best <= v).
	Returns -1 when the term holds no planes, so callers get a defined value
	instead of an uninitialized one.
*/
int SVM::Term::Maximize(double* w)
{
	int t_best = -1;
	double v_best = 0; // overwritten at t == 0; initialized only to keep it defined
	for (int t=0; t<num; t++)
	{
		double v = DotProduct(a[t], w, d) + a[t][d];
		if (t == 0 || v_best <= v) { v_best = v; t_best = t; }
	}
	return t_best;
}

// Stamps plane t_best with the solver's current timestamp. last_accessed is
// what AddPlane() and RemoveUnusedPlanes() consult when choosing planes to drop.
void SVM::Term::UpdateStats(int t_best)
{
	float* stamp = last_accessed + t_best;
	*stamp = svm->timestamp;
}

/*
	Deletes planes whose last_accessed is below svm->timestamp_threshold, always
	keeping at least one plane. Does nothing when the threshold is negative.
	DeletePlane() moves the last plane into the freed slot, so the same index
	is examined again after each deletion.
*/
void SVM::Term::RemoveUnusedPlanes()
{
	if (svm->timestamp_threshold < 0) return;

	int t = 0;
	while (t < num)
	{
		if (num > 1 && last_accessed[t] < svm->timestamp_threshold) DeletePlane(t);
		else t++;
	}
}
