/*********************************************************
** Copyright (c) 2005
** University of Washington
** Licensed under the terms set forth by University of
** Washington. If you did not sign such a license, you
** are using this software/code illegally and you do not
** have permission to use, modify, or redistribute
** this or any files in this software package.
**
** File: VectorDataFunctionsTest.h
**
**********************************************************/
#ifndef __VECTORDATAFUNCTIONSTEST__
#define __VECTORDATAFUNCTIONSTEST__

#include "UnitTestClass.h"
#include "VectorDataFunctions.h"
#include "TestException.h"
#include <iostream>
#include <string>
#include <sstream>
#include <cmath>

using namespace std;

class CVectorDataFunctionsTest : public CUnitTestClass
{
 public:
  CVectorDataFunctionsTest() : CUnitTestClass("VectorDataFunctionsTest")
  {
  }

 protected:
  void RunTests()
  {
    TestCentroidLinkageDistance();
  }

 private:
  void TestCentroidLinkageDistance()
  {
    BeginTest("TestCentroidLinkageDistance");
    bool success = true;
    TDoubleVector data1, data2, data3;
    
    // vector is {2, 3, 5, 7}
    // avg = 4.25
    // s = sqrt(4.19666) ~= 2.217
    data1.push_back(2);
    data1.push_back(3);
    data1.push_back(5);
    data1.push_back(7);

    // vector 2 is {3, 5, 6, 8}
    // avg = 5.5
    // s = sqrt(13/3) ~= 2.0816
    data2.push_back(3);
    data2.push_back(5);
    data2.push_back(6);
    data2.push_back(8);

    // correlation ~= 0.975
    // 1 - correlation = 0.025

    // average of data1, data2 = {2.5, 4, 5.5, 7.5};
    double ex_avg12[] = {2.5, 4, 5.5, 7.5};

    // vector is { 4, 3, 2, 1}
    data3.push_back(4);
    data3.push_back(3);
    data3.push_back(2);
    data3.push_back(1);

    // average of data1, data2, and data3 = 3, 3.6666, 4.33333, 5.333333
    double ex_avg123[] = {3, 3.6666666, 4.3333333, 5.3333333};

    TDoubleVector *pAvg123 = NULL, *pAvg12 = NULL;

    CClusterNode node1((void*)&data1, 0);
    CClusterNode node2((void*)&data2, 1);
    CClusterNode node3((void*)&data3, 2);
    
    TDoubleVector::iterator iter;
    try
    {
      CCorrelation corr(3);
      CCentroidLinkage cl;
      double dist = corr.Distance(&node1, &node2);

      // check that we are within 0.0001 of right answer
      double expected = 0.02508650;
      double diff = fabs(expected - dist);

      if (diff > 0.0001)
	{
	  cout << "expected = " << expected << ", computed = " << dist << "\n";
	  __throw_message("distances not equal");
	}

      pAvg12 = (TDoubleVector*)cl.CombineData(&node1, &node2);
      if(NULL == pAvg12)
	__throw_message("CombineData returned null");
      if(4 != pAvg12->size())
	__throw_message("CombineData does not have 4 elements");
      // check averages
      iter = pAvg12->begin();
      for(int i=0;iter != pAvg12->end();i++,iter++)
	{
	  if( fabs(ex_avg12[i] - *iter) > 0.0001)
	    {
	      cout << "element " << i << " is " << *iter << "; expected " << ex_avg12[i];
	      __throw_message("incorrect average");
	    }
	}

      CClusterNode nodeAvg12(pAvg12, 0, &node1, &node2);
      pAvg123 = (TDoubleVector*)cl.CombineData(&nodeAvg12, &node3);
      if(NULL == pAvg123)
	__throw_message("CombineData returned null");
      if(4 != pAvg123->size())
	__throw_message("CombineData does not have 4 elements");
      // check averages
      iter = pAvg123->begin();
      for(int i=0;iter != pAvg123->end();i++,iter++)
	{
	  if( fabs(ex_avg123[i] - *iter) > 0.0001)
	    {
	      cout << "element " << i << " is " << *iter << "; expected " << ex_avg123[i];
	      __throw_message("incorrect average");
	    }
	}

      CEuclidean euc;
      dist = euc.Distance(&node1, &node2);
      double expDist = sqrt(7.0);
      if( fabs(dist - expDist) > 0.1)
	{
	  cout << "euclidean dist(1,2) calculated as  " << dist << "; expected " << expDist << "\n";
	  __throw_message("incorrect distance");
	}
      dist = euc.Distance(&node2, &node3);
      expDist = sqrt(70.0);
      if( fabs(dist - expDist) > 0.1)
	{
	  cout << "euclidean dist(2,3) calculated as  " << dist << "; expected " << expDist << "\n";
	  __throw_message("incorrect distance");
	}

    }
    catch(CTestException *pEx)
    {
      success = false;
      cout << pEx->Message() << "\n";
      delete pEx;
    }

    if (NULL != pAvg12)
      delete pAvg12;
    if (NULL != pAvg123)
      delete pAvg123;

    EndTest(success);
  }


};

#endif
