/*
    Copyright Vladimir Kolmogorov vnk@ist.ac.at 2014

    This file is part of SVM.

    SVM is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    SVM is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with SVM.  If not, see <http://www.gnu.org/licenses/>.
*/


#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include "SVM.h"
#include "SVMutils.h"
#include "timer.h"


void SVM::AddCuttingPlane(int i, double* a)
{
	if (options.cp_max <= 0) return;
	if (terms[i]->isDuplicate(a)) return;
	terms[i]->AddPlane(a, options.cp_max);
}




void SVM::InitSolver()
{
	int i;

	terms = new Term*[n];
	current_sum = (double*) buf.Alloc((d+1)*sizeof(double));
	SetZero(current_sum, d+1);
	SetZero(w, d);

	total_plane_num = 0;
	timestamp = 0;
	timestamp_threshold = -1;

	for (i=0; i<n; i++)
	{
		terms[i] = new Term(d, this, &buf, (options.kernel_max > 1) ? true : false);
		double* current = terms[i]->current;
		if (zero_lower_bound)
		{
			SetZero(current, d+1);
		}
		else
		{
			max_fn(i, current);
			Add(current_sum, current, d+1); 
		}
		AddCuttingPlane(i, current);
	}
}


double* SVM::Solve()
{
	time_start = get_time();

	if (!terms) InitSolver();
	int _i, i, k;
	
	Multiply(w, current_sum, -lambda_mu_inv, d);
	lower_bound_last = (-Norm(w, d)*lambda_mu/2 + current_sum[d])*mu;


	int approx_max = (options.cp_max <= 0) ? 0 : options.approx_max;
	int avg_num; // maintain 'avg_num' averaged vectors
	switch (options.avg_flag)
	{
		case 0: avg_num = 0; break;
		case 1:
		case 2: avg_num = 1; break;
		default: avg_num = 2; break;
	}

	int vec_size = (d+1)*sizeof(double);
	int alloc_size = (1+2*avg_num)*vec_size;
	if (options.randomize_method >= 1 || options.randomize_method <= 3)	alloc_size += n*sizeof(int);
	char* _buf = new char[alloc_size];
	char* _buf0 = _buf;

	double* current_new = (double*) _buf; _buf += vec_size;
	double* avg[2] = { NULL, NULL };
	for (i=0; i<avg_num; i++)
	{
		avg[i] = (double*) _buf; _buf += vec_size; memcpy(avg[i], current_sum, vec_size);
	}
	int* permutation = NULL;
	if (options.randomize_method >= 1 && options.randomize_method <= 3)	{ permutation = (int*) _buf; _buf += n*sizeof(int); }
	double* avg_buf = (double*) _buf;


	int k_avg[2] = { 0, 0 };
	callback_time = 0;

	if (options.randomize_method == 1) generate_permutation(permutation, n);

	for (iter=total_pass=0; iter<options.iter_max; iter++)
	{
		// recompute current_sum every 10 iterations for numerical stability.
		// Just in case, the extra runtime should be negligible.
		// For experiments in the paper this was not used (numerical stability was not an issue)
		if (iter > 0 && (iter % 10) == 0)
		{
			SetZero(current_sum, d+1);
			for (i=0; i<n; i++) Add(current_sum, terms[i]->current, d+1);
		}

		////////////////////////////////////////////////////////////////////////////////////

		timestamp = (float)(((int)timestamp) + 1); // When a plane is accessed, it is marked with 'timestamp'.
		                                           // Throughout the outer iteration, this counter will be gradually
		                                           // increased from 'iter+1' to 'iter+1.5', so that we
		                                           // (1) we can distinguish between planes added in the same iteration (when removing the oldest plane), and
		                                           // (2) we can easily determine whether a plane has been active during the last 'cp_inactive_iter_max' iterations
		if (options.cp_inactive_iter_max > 0) timestamp_threshold = timestamp - options.cp_inactive_iter_max;

		if (options.randomize_method == 2) generate_permutation(permutation, n);

		double _t[2];           // index 0: before calling real oracle
		double _lower_bound[2]; // index 1: after calling real oracle

		_t[0] = get_time();
		_lower_bound[0] = lower_bound_last;

		for (approx_pass=-1; approx_pass<approx_max; approx_pass++, total_pass++)
		{
			timestamp += (float) ( 0.5 / (approx_max+1) );

			if (options.randomize_method == 3) generate_permutation(permutation, n);

			for (_i=0; _i<n; _i++)
			{
				if (permutation)                        i = permutation[_i];
				else if (options.randomize_method == 0) i = _i;
				else                                    i = RandomInteger(n);
			
				double* current = terms[i]->current;

				if (approx_pass < 0) // call real oracle
				{
					max_fn(i, current_new); 
					AddCuttingPlane(i, current_new);
				}
				else  // call approximate oracle
				{
					if (options.kernel_max > 1)
					{
						SolveWithKernel(i, options.kernel_max);
						Multiply(w, current_sum, -lambda_mu_inv, d);
						terms[i]->RemoveUnusedPlanes();

						// averaging
						int p = (approx_pass < 0 || options.avg_flag == 1) ? 0 : 1;
						if (avg[p])
						{
							double gamma = 2.0 / (2 + (k_avg[p] ++));
							Interpolate(avg[p], current_sum, gamma, d+1);
						}

						continue;
					}

					int t = terms[i]->Maximize(w);
					terms[i]->UpdateStats(t);
					memcpy(current_new, terms[i]->a[t], (d+1)*sizeof(double));
					terms[i]->RemoveUnusedPlanes();
				}

				// min_{gamma \in [0,1]} B*gamma*gamma - 2*A*gamma
				double A = Op1(current, current_new, current_sum, d) + (current_new[d] - current[d])*lambda_mu; // <current-current_new,current_sum> + (b_new - b) * lambda
				double B = Op2(current, current_new, d); // ||current-current_new||^2
				double gamma;
				if (B<=0) gamma = (A <= 0) ? 0 : 1;
				else
				{
					gamma = A/B;
					if (gamma < 0) gamma = 0;
					if (gamma > 1) gamma = 1;
				}

				for (k=0; k<=d; k++)
				{
					double old = current[k];
					current[k] = (1-gamma)*current[k] + gamma*current_new[k];
					current_sum[k] += current[k] - old;
				}
				Multiply(w, current_sum, -lambda_mu_inv, d);

				// averaging
				int p = (approx_pass < 0 || options.avg_flag == 1) ? 0 : 1;
				if (avg[p])
				{
					gamma = 2.0 / (2 + (k_avg[p] ++));
					Interpolate(avg[p], current_sum, gamma, d+1);
				}
			}

			double t = get_time();
			lower_bound_last = (-Norm(w, d)*lambda_mu/2 + current_sum[d])*mu;

			if (approx_pass >= 0)
			{
				if ( (lower_bound_last - _lower_bound[1]) * (_t[1]-_t[0]) * options.approx_limit_ratio
				   < (_lower_bound[1]  - _lower_bound[0]) * (t-_t[1])      ) { approx_pass ++; break; }
			}

			_t[1] = t;
			_lower_bound[1] = lower_bound_last;
		}

		time_from_start = get_time() - time_start;
		if (options.callback_fn && (iter % options.callback_freq) == 0)
		{
			double t0 = get_time();

			bool res;
			double* w_tmp;
			double* current_sum_tmp;

			if (avg_num == 1)
			{
				current_sum_tmp = current_sum; w_tmp = w;
				current_sum = avg[0]; w = avg_buf;
				Multiply(w, current_sum, -lambda_mu_inv, d);
				res = (*options.callback_fn)(this);
				current_sum = current_sum_tmp; w = w_tmp;
			}
			else if (avg_num == 2)
			{
				current_sum_tmp = current_sum; w_tmp = w;
				current_sum = avg_buf; w = avg_buf + d+1;
				InterpolateBest(avg[0], avg[1], current_sum);
				Multiply(w, current_sum, -lambda_mu_inv, d);
				res = (*options.callback_fn)(this);
				current_sum = current_sum_tmp; w = w_tmp;
			}
			else
			{
				res = (*options.callback_fn)(this);
			}

			callback_time += get_time() - t0;

			if (!res) break;
		}
	}

	if (avg_num == 1)
	{
		memcpy(current_sum, avg[0], vec_size);
		Multiply(w, current_sum, -lambda_mu_inv, d);
	}
	else if (avg_num == 2)
	{
		InterpolateBest(avg[0], avg[1], current_sum);
		Multiply(w, current_sum, -lambda_mu_inv, d);
	}

	delete [] _buf0;
	return w;
}

void SVM::GetBounds(double& lower_bound, double& upper_bound)
{
	int i;
	double* tmp = new double[d+1];

	double norm = Norm(w, d)*lambda_mu/2;
	lower_bound = -norm + current_sum[d];
	upper_bound = 0;
	for (i=0; i<n; i++)
	{
		max_fn(i, tmp);
		upper_bound += DotProduct(w, tmp, d) + tmp[d];
	}
	upper_bound += norm;

	lower_bound *= mu;
	upper_bound *= mu;

	delete [] tmp;
}

void SVM::SolveWithKernel(int _i, int iter_max)
{
	Term* T = terms[_i];
	int num = T->num, i, t, iter;
	double** kk = T->products;
	double* ck = (double*) rbuf_SolveWithKernel.Alloc(3*num*sizeof(double)); // ck[i] = DotProduct(current, T->a[i], d)
	double* sk = ck + num; // sk[i] = DotProduct(current_sum, T->a[i], d)
	double cc = DotProduct(T->current, T->current, d);
	double cs = DotProduct(T->current, current_sum, d);
	double c_d = T->current[d];

	double* x = sk + num;
	double cx = 1;

	double gamma;

	for (i=0; i<num; i++)
	{
		ck[i] = DotProduct(T->current, T->a[i], d);
		sk[i] = DotProduct(current_sum, T->a[i], d);
		x[i] = 0;
	}

	for (iter=0; iter<iter_max; iter++)
	{
		if (iter > 0)
		{
			// current += gamma*(a[t] - current)
			// sum     += gamma*(a[t] - current)

			c_d += gamma*(T->a[t][d] - c_d);
			double cc_new = cc + 2*gamma*(ck[t] - cc) + gamma*gamma*(kk[t][t] - 2*ck[t] + cc);
			double cs_new = cs+ gamma*(ck[t] + sk[t] - cc - cs) + gamma*gamma*(kk[t][t] - 2*ck[t] + cc);
			cc = cc_new;
			cs = cs_new;

			for (i=0; i<num; i++)
			{
				if (kk[i][t] == NOT_YET_COMPUTED) kk[i][t] = DotProduct(T->a[i], T->a[t], d);
				double delta = gamma*(kk[i][t] - ck[i]);
				ck[i] += delta;
				sk[i] += delta;
			}

/*
double* current_tmp = new double[2*(d+1)];
double* current_sum_tmp = current_tmp + d+1;
for (i=0; i<=d; i++)
{
	current_sum_tmp[i] = current_sum[i] - T->current[i];
	current_tmp[i] = T->current[i] * cx;
}
for (t=0; t<num; t++)
{
	if (x[t] == 0) continue;
	for (i=0; i<=d; i++) current_tmp[i] += x[t]*T->a[t][i];
}
for (i=0; i<=d; i++)
{
	current_sum_tmp[i] += current_tmp[i];
}

printf("cc: %f %f\n", cc, DotProduct(current_tmp, current_tmp, d));
printf("cs: %f %f\n", cs, DotProduct(current_tmp, current_sum_tmp, d));
printf("c_d: %f %f\n", c_d, current_tmp[d]);
for (i=0; i<num; i++)
{
	printf("ck[%d]: %f %f\n", i, ck[i], DotProduct(current_tmp, T->a[i], d));
}
for (i=0; i<num; i++)
{
	printf("sk[%d]: %f %f\n", i, sk[i], DotProduct(current_sum_tmp, T->a[i], d));
}
getchar();
delete [] current_tmp;
*/
		}

		t = 0;
		double v_best;
		for (i=0; i<num; i++)
		{
			double v = -sk[i] * lambda_mu_inv + T->a[i][d];
			if (i==0 || v_best < v) { t = i; v_best = v; }
		}
		T->UpdateStats(t);

		if (kk[t][t] == NOT_YET_COMPUTED) kk[t][t] = DotProduct(T->a[t], T->a[t], d);

		// min_{gamma \in [0,1]} B*gamma*gamma - 2*A*gamma
		double A = cs - sk[t] + (T->a[t][d] - c_d)*lambda_mu;
		double B = cc + kk[t][t] - 2*ck[t];

		if (B<=0) gamma = (A <= 0) ? 0 : 1;
		else
		{
			gamma = A/B;
			if (gamma < 0) gamma = 0;
			if (gamma > 1) gamma = 1;
		}

		cx *= 1-gamma;
		for (i=0; i<num; i++) x[i] *= 1-gamma;
		x[t] += gamma;
	}

	for (i=0; i<=d; i++)
	{
		current_sum[i] -= T->current[i];
		T->current[i] *= cx;
	}
	for (t=0; t<num; t++)
	{
		if (x[t] == 0) continue;
		for (i=0; i<=d; i++) T->current[i] += x[t]*T->a[t][i];
	}
	for (i=0; i<=d; i++)
	{
		current_sum[i] += T->current[i];
	}
}

// Writes to 's_best' the combination (1-gamma)*s1 + gamma*s2 of two (d+1)-vectors, where
// gamma in [0,1] minimizes B*gamma*gamma - 2*A*gamma with
//   A = Op1(s1, s2, d) + (s2[d] - s1[d])*lambda_mu   (Op1 presumably returns <s1-s2, s1> - confirm in SVMutils.h)
//   B = Op2(s1, s2, d)
// If B <= 0, gamma is 0 when A <= 0 and 1 otherwise.
void SVM::InterpolateBest(double* s1, double* s2, double* s_best)
{
	const double A = Op1(s1, s2, d) + (s2[d] - s1[d])*lambda_mu;
	const double B = Op2(s1, s2, d);

	double gamma;
	if (B<=0)
	{
		if (A <= 0) gamma = 0;
		else        gamma = 1;
	}
	else
	{
		gamma = A/B;
		if (gamma < 0)      gamma = 0;
		else if (gamma > 1) gamma = 1;
	}

	for (int c=0; c<d+1; c++)
	{
		s_best[c] = (1-gamma)*s1[c] + gamma*s2[c];
	}
}

