/*********************************************************
** Copyright (c) 2005
** University of Washington
** Licensed under the terms set forth by University of
** Washington. If you did not sign such a license, you
** are using this software/code illegally and you do not
** have permission to use, modify, or redistribute
** this or any files in this software package.
**
** File: ClusterTest.h
**
**********************************************************/
#ifndef __CLUSTERTEST__
#define __CLUSTERTEST__

#include "UnitTestClass.h"
#include "ClusterNode.h"
#include "TestException.h"
#include "BasicCluster.h"
#include "ClusterException.h"
#include "OptimalOrdering.h"
#include <stdint.h>
#include <iostream>
#include <sstream>
#include <cmath>

using namespace std;

/*
 * Distance functor for scalar test data: each node's data pointer is read as
 * a double*, and the distance between two nodes is |a - b|.
 */
class CAbsoluteDifference : public CDistanceFunction
{
 public:
  /*
   * Both nodes, and the data each one carries, must be non-NULL (checked
   * with assert). Returns the absolute difference of the two doubles.
   */
  double Distance(CClusterNode* pNode1, CClusterNode* pNode2)
  {
    assert (NULL != pNode1 && NULL != pNode2);

    double* pData1 = (double*)pNode1->GetData();
    double* pData2 = (double*)pNode2->GetData();
    assert (NULL != pData1 && NULL != pData2);

    const double delta = *pData1 - *pData2;
    return fabs(delta);
  }

  /* The functor keeps no working storage, so it reports zero bytes. */
  uintmax_t GetMemoryRequirement()
  {
    return 0;
  }
};

/*
 * Combine functor for scalar test data: the merged node's data is the plain
 * (unweighted) mean of the two input nodes' doubles.
 */
class CAverage : public CCombineDataFunction
{
 public:
  /*
   * Throws a heap-allocated std::exception if either node is NULL.
   * Returns a new double holding (a + b) / 2. The result is allocated with
   * new; presumably the cluster/node code takes ownership — TODO confirm.
   */
  void* CombineData(CClusterNode* pNode1, CClusterNode* pNode2)
  {
    const bool missingNode = (pNode1 == NULL) || (pNode2 == NULL);
    if (missingNode)
      throw new exception(); // todo: prettier exceptions

    const double first = *( (double*)pNode1->GetData() );
    const double second = *( (double*)pNode2->GetData() );

    const double total = first + second;
    return new double(total / 2.0);
  }

  /* No working storage is needed, so zero bytes are reported. */
  uintmax_t GetMemoryRequirement()
  {
    return 0;
  }
};

class CLeafAverage : public CCombineDataFunction
{
 public:
  void* CombineData(CClusterNode* pNode1, CClusterNode* pNode2)
  {
    if (NULL == pNode1 || NULL == pNode2)
      {
	throw new exception(); // todo: prettier exceptions
      }
    int leafCount1 = pNode1->IsLeaf() ? 1 : pNode1->GetLeftLeaves()->size() + pNode1->GetRightLeaves()->size();
    int leafCount2 = pNode2->IsLeaf() ? 1 : pNode2->GetLeftLeaves()->size() + pNode2->GetRightLeaves()->size();

    double avg = (*((double*)pNode1->GetData()) * leafCount1 + 
		*((double*)pNode2->GetData()) * leafCount2)  
      / (leafCount1 + leafCount2);

    return new double(avg);

  }
  uintmax_t GetMemoryRequirement() { return 0; }
};

/*
 * Unit tests for CBasicCluster and COptimalOrdering on small scalar data
 * sets, plus the memory-accounting helpers (GetMemoryRequirement and
 * CheckMemory) used by the clustering code.
 */
class CClusterTest : public CUnitTestClass
{
 public:
  CClusterTest() : CUnitTestClass("ClusterTest")
  {
  }

 protected:
  /* Runs every test case in this suite, in order. */
  void RunTests()
  {
    TestCluster();
    TestLargeCluster();
    TestMemoryRequirement();
    TestCheckMemory();
  }

 private:
  /*
   * Clusters the scalars {3, 7, 6, 8} with absolute-difference distance and
   * pairwise averaging, then checks the tree shape and leaf data both before
   * and after optimal leaf ordering.
   */
  void TestCluster()
  {
    BeginTest("TestCluster");
    bool success = true;

    // Allocated with new, so they are released with delete below
    // (calling free() on memory obtained from new is undefined behavior).
    double* pData[4] = {new double(3), new double(7), new double(6), new double(8)};

    // Declared outside the try so the tree can be deleted afterwards, after
    // the cluster object has gone out of scope (same order as TestLargeCluster).
    CClusterNode* root = NULL;
    try
    {
      CAbsoluteDifference diff;
      CAverage combine;
      CBasicCluster cluster((CDistanceFunction*)&diff, (CCombineDataFunction*)&combine);

      // pData decays to double**; pass it directly rather than &pData,
      // which is a pointer to the whole array.
      root = cluster.Cluster((void**)pData, 4);
      if (root == 0)
        __throw_message("root of cluster is null");

      /* check results */
      CClusterNode *n;

      n = root->GetLeftChild();
      if (n == 0)
        __throw_message("left child of root is null");
      if (*((double*)n->GetData()) != 3.0)
        __throw_message("left leaf should have '3' as data but doesn't");

      n = root->GetRightChild();
      if (n == 0)
        __throw_message("right child of root is null");

      n = n->GetLeftChild();
      if (n == 0)
        __throw_message("root->right->left is null");
      if (*((double*)n->GetData()) != 8)
        __throw_message("root->right->left data is not 8 as expected");

      n = n->GetParent()->GetRightChild();
      if (n == 0)
        __throw_message("root->right->right is null");
      n = n->GetLeftChild();
      if (n == 0)
        __throw_message("root->right->right->left is null");
      if (*((double*)n->GetData()) != 7)
        __throw_message("root->right->right->left->data is not 7");

      n = n->GetParent()->GetRightChild();
      if (n == 0)
        __throw_message("root->right->right->right is null");
      if (*((double*)n->GetData()) != 6)
        __throw_message("root->right->right->right->data is not 6");

      /* attempt ordering! */
      COptimalOrdering ordering;
      ordering.OrderLeaves(root, cluster.GetDistanceMatrix());

      n = root->GetLeftChild();
      if (n == 0)
        __throw_message("left child of root is null");
      if (*((double*)n->GetData()) != 3.0)
        __throw_message("left leaf should have '3' as data but doesn't");

      n = root->GetRightChild();
      if (n == 0)
        __throw_message("right child of root is null");

      n = n->GetLeftChild();
      if (n == 0)
        __throw_message("root->right->left is null");
      n = n->GetLeftChild();
      if (n == 0)
        __throw_message("root->right->left->left is null");
      if (*((double*)n->GetData()) != 6)
        __throw_message("root->right->left->left->data is not 6");

      n = n->GetParent()->GetRightChild();
      if (n == 0)
        __throw_message("root->right->left->right is null");
      if (*((double*)n->GetData()) != 7)
        __throw_message("root->right->left->right->data is not 7");

      n = n->GetParent()->GetParent()->GetRightChild();
      if (n == 0)
        __throw_message("root->right->right is null");
      if (*((double*)n->GetData()) != 8)
        __throw_message("root->right->right data is not 8 as expected");
    }
    catch(CTestException *pEx)
    {
      success = false;
      cout << pEx->Message() << "\n";
      delete pEx;
    }
    catch(CClusterException *pEx)
    {
      success = false;
      cout << pEx->Message() << "\n";
      delete pEx;
    }
    catch(...)
    {
      success = false;
      cout << "something's wrong\n";
    }

    // Free the tree first, while the leaf data it points at is still alive.
    delete root;

    for (int i = 0; i < 4; i++)
      {
        delete pData[i];
      }

    EndTest(success);
  }

  /*
   * Clusters nine shuffled scalars with leaf-weighted averaging, applies
   * optimal leaf ordering, and checks that the leaves read "9 8 7 ... 1".
   */
  void TestLargeCluster()
  {
    bool success = true;
    BeginTest("TestLargeCluster");

    // Leaf data lives outside the try so it outlives the tree, which is
    // deleted after the try/catch.
    double pData[] = {5, 7, 1, 2, 4, 9, 8, 3, 6};
    double* ppData[9];
    for (int i = 0; i < 9; i++)
      {
        ppData[i] = &pData[i];
      }

    CClusterNode* root = NULL;

    try
    {
      string expected = "9 8 7 6 5 4 3 2 1 ";

      CAbsoluteDifference distanceFunc;
      CLeafAverage combineFunc;
      CLeavesToString<double> func;

      CBasicCluster cluster((CDistanceFunction*)&distanceFunc, (CCombineDataFunction*)&combineFunc);
      root = cluster.Cluster((void**)ppData, 9);
      if (root == 0)
        __throw_message("root of cluster is null");

      COptimalOrdering ordering;
      ordering.OrderLeaves(root, cluster.GetDistanceMatrix());

      root->WalkLeaves(&func);
      cout << func.GetString() << "\n";
      if (func.GetString() != expected)
        __throw_message("leaf ordering was not as expected");
    }
    catch(CTestException* pEx)
    {
      success = false;
      cout << pEx->Message() << "\n";
      delete pEx;
    }
    catch(CClusterException* pEx)
    {
      // Clustering failures are reported as a test failure, as in TestCluster.
      success = false;
      cout << pEx->Message() << "\n";
      delete pEx;
    }

    delete root;

    EndTest(success);
  }

  /*
   * Checks the reported memory requirements of the Euclidean distance and
   * centroid-linkage combine functors, and of CBasicCluster for 1000
   * data items of 12 bytes each.
   */
  void TestMemoryRequirement()
  {
    bool success = true;
    BeginTest("TestMemoryRequirement");

    uintmax_t datalen = 1000;
    // Pairwise merging of datalen leaves produces 2*datalen - 1 nodes (1999).
    uintmax_t nodecount = 2*datalen - 1;
    uintmax_t datasize = 12;
    try
    {
      CEuclidean distanceFunc;
      CCentroidLinkage combineFunc;

      CBasicCluster cluster((CDistanceFunction*)&distanceFunc, (CCombineDataFunction*)&combineFunc);

      uintmax_t actual = distanceFunc.GetMemoryRequirement();
      uintmax_t expected = 2*(sizeof(TDoubleVector::iterator)) + 10*sizeof(double);
      if (actual != expected)
        {
          cout << "expected memory requirement for distance function " << expected << "; actual " << actual << "\n";
          __throw_message("Memory requirement was not as expected");
        }

      actual = combineFunc.GetMemoryRequirement();
      expected = 2*(sizeof(TDoubleVector::iterator)) + 10*sizeof(double);
      if (actual != expected)
        {
          cout << "expected memory requirement for combine function " << expected << "; actual " << actual << "\n";
          __throw_message("Memory requirement was not as expected");
        }

      actual = cluster.GetMemoryRequirement(datasize, datalen);
      expected = combineFunc.GetMemoryRequirement() + distanceFunc.GetMemoryRequirement() + CUpperDiagonalMatrix::GetMemoryRequirement(nodecount) + sizeof(double)*50 + (sizeof(CClusterNode) + sizeof(CClusterNode*) + datasize + sizeof(char))*nodecount;

      if (actual != expected)
        {
          cout << "expected memory requirement " << expected << "; actual " << actual << "\n";
          __throw_message("Memory requirement was not as expected");
        }
    }
    catch(CTestException* pEx)
    {
      success = false;
      cout << pEx->Message() << "\n";
      delete pEx;
    }
    catch(CClusterException* pEx)
    {
      success = false;
      cout << pEx->Message() << "\n";
      delete pEx;
    }

    EndTest(success);
  }

  /*
   * Checks CUpperDiagonalMatrix::GetMemoryRequirement for 200000 items and
   * that CBasicCluster::CheckMemory throws CClusterException* for two
   * requests that cannot be satisfied.
   */
  void TestCheckMemory()
  {
    bool success = true;
    BeginTest("TestCheckMemory");

    uintmax_t datalen = 200000;
    uintmax_t datasize = 228*sizeof(double);
    try
    {
      CEuclidean distanceFunc;
      CCentroidLinkage combineFunc;

      uintmax_t expected = 320000000000ULL;
      uintmax_t actual = CUpperDiagonalMatrix::GetMemoryRequirement(datalen);
      __test_assert_nomsg(actual == expected);

      CBasicCluster cluster((CDistanceFunction*)&distanceFunc, (CCombineDataFunction*)&combineFunc);

      try
      {
        // This checks for a huge amount of memory & should fail
        cluster.CheckMemory(datasize, datalen);
        success = false;
        __throw_message("Memory check did not fail (but should have)");
      }
      catch(CClusterException* pEx)
      {
        delete pEx;
      }

      try
      {
        // A per-item size of UINTMAX_MAX cannot be satisfied (and presumably
        // exercises overflow handling in the size arithmetic); should fail.
        cluster.CheckMemory(UINTMAX_MAX, 1);
        success = false;
        __throw_message("Memory check 2 did not fail (but should have)");
      }
      catch(CClusterException* pEx)
      {
        delete pEx;
      }
    }
    catch(CTestException* pEx)
    {
      success = false;
      cout << pEx->Message() << "\n";
      delete pEx;
    }
    catch(CClusterException* pEx)
    {
      // e.g. thrown while constructing the cluster, outside the inner checks.
      success = false;
      cout << pEx->Message() << "\n";
      delete pEx;
    }

    EndTest(success);
  }

};

#endif
