/*********************************************************
** Copyright (c) 2005
** University of Washington
** Licensed under the terms set forth by University of
** Washington. If you did not sign such a license, you
** are using this software/code illegally and you do not
** have permission to use, modify, or redistribute
** this or any files in this software package.
**
** File: Rintf_cluster.cpp  $Revision: 236 $
**
**********************************************************/
#include "ProgressThread.h"
#include "Cluster.h"
#include "VectorDataFunctions.h"
#include "BasicCluster.h"
#include "ClusterThread.h"
#include "Ordering.h"
#include "EisenOrdering.h"
#include "OptimalOrdering.h"
#include "ClusterException.h"

#include "Rintf_cluster.h"

extern "C"
{

  // Clustering thread

  // Starts a clustering thread.
  SEXP StartClusteringThread(SEXP sexpData, SEXP sexpRows, SEXP sexpCols,
			      SEXP sexpOrderMethod, 
			      SEXP sexpDistanceMethod)
  {
    SEXP sthreadid;

    int nProt = 0;
    PROTECT(sexpData = AS_NUMERIC(sexpData)); nProt++;
    PROTECT(sexpRows = AS_INTEGER(sexpRows)); nProt++;
    PROTECT(sexpCols = AS_INTEGER(sexpCols)); nProt++;
    PROTECT(sexpOrderMethod = AS_INTEGER(sexpOrderMethod)); nProt++;
    PROTECT(sexpDistanceMethod = AS_INTEGER(sexpDistanceMethod)); nProt++;
    PROTECT(sthreadid = NEW_INTEGER(1)); nProt++;

    double* pData = NUMERIC_POINTER(sexpData);
    int rows = INTEGER_VALUE(sexpRows);
    int cols = INTEGER_VALUE(sexpCols);
    ORDER_METHOD orderMethod = (ORDER_METHOD)INTEGER_VALUE(sexpOrderMethod);
    DISTANCE_METHOD distanceMethod = (DISTANCE_METHOD)INTEGER_VALUE(sexpDistanceMethod);

    if (NULL == pData)
      {
	Rf_error("pData is null");
	UNPROTECT(nProt);
	return R_NilValue;
      }
    if (NA_INTEGER == rows)
      {
	Rf_error("rows is not defined");
	UNPROTECT(nProt);
	return R_NilValue;
      }
    if (NA_INTEGER == cols)
      {
	Rf_error("cols is not defined", "");
	UNPROTECT(nProt);
	return R_NilValue;
      }
    
    /* Cluster the data.*/
    
    CDistanceFunction* pDist = NULL;
    COrdering *pOrdering = NULL;
    CCentroidLinkage* pCentroid = NULL;
    CBasicCluster* pCluster = NULL;
    TDoubleVector** vec = NULL;
    CClusterThread *pThread = NULL;
    CClusterException *pError = NULL;

    try {
      switch(distanceMethod) 
	{
	case(CORRELATION):
	  pDist = new CCorrelation(rows);
	  break;
	case(EUCLIDEAN):
	default:
	  pDist = new CEuclidean();
	  break;
	}

      switch(orderMethod)
	{
	case(EISEN):
	  pOrdering = new CEisenOrdering();
	  break;
	case(OPTIMAL):
	  pOrdering = new COptimalOrdering();
	  break;
	case(NONE):
	default:
	  break;
	}
      
      pCentroid = new CCentroidLinkage();
      pCluster = new CBasicCluster(pDist, pCentroid);
      
      /* Put the data into a format that the clustering class will understand*/
      vec = VectorizeRMatrix(pData, rows, cols);
      pCluster->CheckMemory(cols*sizeof(double), rows);

      pThread = new CClusterThread((CCluster*)pCluster, 
				   vec, rows, pOrdering, false);
      if (NULL != pOrdering)
	pOrdering->SetCallback(&UpdateProgressThreadProgress, (void*)pThread);
      pCluster->SetCallback(&UpdateProgressThreadProgress, (void*)pThread);

    } 
    catch(CClusterException* pEx) {

      pError = pEx;
    }
    
    if (NULL != pError)
      {
	// copy error message into something that will get deleted
	// when it goes out of scope
	string s(pError->Message());
	if(pError->Id() == CLUSTEX_OUTOFMEMORY)
	  s = "Not enough memory. Your dataset may be too large.";

	// Clean up
	delete pCluster;
	delete pDist;
	delete pOrdering;
	delete pThread;
	if (NULL != vec) {
	  for(int i = 0; i < rows; i++)
	    {
	      delete vec[i];
	    }
	  delete [] vec;
	}
	delete pError;

	// R error
	Rf_error(s.c_str());
	UNPROTECT(nProt);
	return R_NilValue;
      }

    int status = pThread->Start();
    if(0 != status)
      {
	// Clean up
	if(!pThread->GetIsStarted())
	  {
	    delete pCluster;
	    delete pDist;
	    delete pOrdering;
	    delete pThread;
	    if (NULL != vec) {
	      for(int i = 0; i < rows; i++)
		{
		  delete vec[i];
		}
	      delete [] vec;
	    }
	    Rf_error("Could not start thread (error number %d)", status);
	    UNPROTECT(nProt);
	    return R_NilValue;
	  }
	else
	  Rf_warning("Thread %d was started, but returned status code %d", 
		     pThread->GetThreadId(), status);
      }
    
    INTEGER_POINTER(sthreadid)[0] = (int)pThread->GetThreadId();

    UNPROTECT(nProt);
    return(sthreadid);
  }

  // Gathers the results of a finished clustering thread.
  //
  // Arguments:
  //   sthreadid  integer id previously returned by StartClusteringThread
  //   sexpMerge  integer buffer receiving the "merge" matrix, stored
  //              column by column: 2 * (rows - 1) entries
  //   sexpHeight numeric buffer receiving rows - 1 merge heights
  //   sexpOrder  integer buffer receiving the leaf order
  // where rows is the thread's GetDataCount().  The buffers are written
  // through the pointers obtained after coercion; their lengths are not
  // checked here, so the caller must size them correctly.
  //
  // Returns an empty integer vector.  Raises an R error if the thread is
  // unknown, canceled, still running, or finished with an error.
  // Rf_error() does not return, so nothing follows those calls.
  SEXP FinishClusteringThread(SEXP sthreadid,
                              SEXP sexpMerge,
                              SEXP sexpHeight,
                              SEXP sexpOrder)
  {
    int nProt = 0;
    PROTECT(sthreadid = AS_INTEGER(sthreadid)); nProt++;
    // Keep the id as a plain int for diagnostics: THREAD_ID_T need not be
    // int-sized, and passing it to a "%d" conversion would be undefined.
    const int reportedId = INTEGER_VALUE(sthreadid);
    THREAD_ID_T threadid = (THREAD_ID_T)reportedId;

    PROTECT(sexpMerge = AS_INTEGER(sexpMerge)); nProt++;
    PROTECT(sexpHeight = AS_NUMERIC(sexpHeight)); nProt++;
    PROTECT(sexpOrder = AS_INTEGER(sexpOrder)); nProt++;
    int* pMerge = INTEGER_POINTER(sexpMerge);
    double* pHeight = NUMERIC_POINTER(sexpHeight);
    int* pOrder = INTEGER_POINTER(sexpOrder);

    CClusterNode* pResult;
    CBasicCluster* pCluster;

    // Get the thread
    CClusterThread* pThread =
      (CClusterThread*)CProgressThreadManager::GetProgressThread(threadid);
    if (NULL == pThread)
      Rf_error("Could not get information for thread id %d", reportedId);

    if(pThread->GetIsCanceled())
      Rf_error("Thread was canceled");

    if(!pThread->GetIsFinished())
      Rf_error("Thread is still running");

    // Check the return status
    int returnStatus = pThread->GetReturnStatus();
    if(0 != returnStatus)
      {
        // Copy the text into a plain buffer first: Rf_error() unwinds with
        // a longjmp, which would skip the destructor of any string object
        // still alive at the call.
        char message[1024];
        {
          string s(pThread->GetErrorMessage().c_str());
          const string::size_type copied = s.copy(message, sizeof(message) - 1);
          message[copied] = '\0';
        }
        Rf_error("An error occurred during clustering: %s", message);
      }

    if (NULL == (pResult = pThread->GetResult()))
      Rf_error("Result tree is null for thread id %d", reportedId);

    if (NULL == (pCluster = (CBasicCluster*)pThread->GetClusterer()))
      Rf_error("Clustering object is null for thread id %d", reportedId);

  /* build the merge structure.
   * From the R help files:
   *
   * "merge is an n-1 by 2 matrix.
   * Row i of merge describes the merging of clusters
   * at step i of the clustering. If an element j in
   * the row is negative, then observation -j was merged
   * at this stage. If j is positive then the merge was
   * with the cluster formed at the (earlier) stage j
   * of the algorithm. Thus negative entries in merge
   * indicate agglomerations of singletons, and positive
   * entries indicate agglomerations of non-singletons."
   *
   * Nodes rows .. 2*rows-2 are visited in order, one per merge row.
   * For each child, index = GetIndex() + 1 (1-based):
   *   index <= rows  -> the child is written as -index
   *   otherwise      -> the child is written as index - rows
   * The left child goes to column 1 (offset0) and the right child to
   * column 2 (offset1 = offset0 + rows - 1).
   */
    int rows = pThread->GetDataCount();
    int heightLen = rows - 1;
    int nodeIndex = rows;
    int offset1 = rows - 1;
    for(int offset0 = 0; offset0 < rows-1; offset0++, offset1++, nodeIndex++)
      {
        CClusterNode* pNode = pCluster->GetNodeByIndex(nodeIndex);
        int index = pNode->GetLeftChild()->GetIndex()+1;
        pMerge[offset0] = (index <= rows) ? -index : index - rows;
        index = pNode->GetRightChild()->GetIndex()+1;
        pMerge[offset1] = (index <= rows) ? -index : index - rows;
      }


    /* Create the "height" structure: the distance between the two
     * children of node rows + i. */
    for(int i = 0; i < heightLen; i++)
      {
        // get node n + i.
        CClusterNode *pNode = pCluster->GetNodeByIndex(rows+i);
        pHeight[i] = pCluster->GetDistanceFunction()->Distance
          (pNode->GetLeftChild(), pNode->GetRightChild());
      }

    /* construct the "order" structure */
    CLeafOrdering ordering(pOrder);
    pResult->WalkLeaves(&ordering);

    /* Do not delete pThread or its members - in case user wants to access thread values later */

    SEXP ret;
    PROTECT(ret = NEW_INTEGER(0)); nProt++;
    UNPROTECT(nProt);

    return ret;

  }

  /*
   * Creates a hierarchical cluster of data, synchronously on the calling
   * thread, and fills in the components of an R "hclust" structure.
   *
   * sexpData           numeric data, read column by column as a
   *                    rows x cols matrix
   * sexpRows, sexpCols matrix dimensions (must not be NA)
   * sexpMerge          integer output buffer, length (rows - 1) * 2
   * sexpHeight         numeric output buffer, length rows - 1
   * sexpOrder          integer output buffer, length rows
   * sexpOrderMethod    ORDER_METHOD value; EISEN or OPTIMAL reorder the
   *                    leaves before "order" is written, anything else
   *                    leaves the tree as clustered
   * sexpDistanceMethod DISTANCE_METHOD value; CORRELATION selects
   *                    CCorrelation, anything else CEuclidean
   *
   * The output buffer lengths are assumed, not checked.
   *
   * Returns an integer vector of length 23 whose elements are not set
   * here.  On failure an R error is raised.  Rf_error() does not return
   * (R unwinds with a longjmp), so it is only called once every C++
   * object created here has been destroyed.
   */
  SEXP Cluster(SEXP sexpData, SEXP sexpRows, SEXP sexpCols, SEXP sexpMerge, SEXP sexpHeight, SEXP sexpOrder, SEXP sexpOrderMethod, SEXP sexpDistanceMethod)
  {

    int nProt = 0;
    PROTECT(sexpData = AS_NUMERIC(sexpData)); nProt++;
    PROTECT(sexpRows = AS_INTEGER(sexpRows)); nProt++;
    PROTECT(sexpCols = AS_INTEGER(sexpCols)); nProt++;
    PROTECT(sexpMerge = AS_INTEGER(sexpMerge)); nProt++;
    PROTECT(sexpHeight = AS_NUMERIC(sexpHeight)); nProt++;
    PROTECT(sexpOrder = AS_INTEGER(sexpOrder)); nProt++;
    PROTECT(sexpOrderMethod = AS_INTEGER(sexpOrderMethod)); nProt++;
    // Coerced and protected like every other argument (and like
    // StartClusteringThread does).
    PROTECT(sexpDistanceMethod = AS_INTEGER(sexpDistanceMethod)); nProt++;

    double* pData = NUMERIC_POINTER(sexpData);
    int* pMerge = INTEGER_POINTER(sexpMerge);
    double* pHeight = NUMERIC_POINTER(sexpHeight);
    int* pOrder = INTEGER_POINTER(sexpOrder);
    int rows = INTEGER_VALUE(sexpRows);
    int cols = INTEGER_VALUE(sexpCols);
    ORDER_METHOD orderMethod = (ORDER_METHOD)INTEGER_VALUE(sexpOrderMethod);
    DISTANCE_METHOD distanceMethod = (DISTANCE_METHOD)INTEGER_VALUE
      (sexpDistanceMethod);

    // Argument validation.  No C++ objects exist yet, so raising the R
    // error directly is safe.
    if (NULL == pData)
      Rf_error("pData is null");
    if (NA_INTEGER == rows)
      Rf_error("rows is not defined");
    if (NA_INTEGER == cols)
      Rf_error("cols is not defined");
    if (NULL == pMerge)
      Rf_error("pMerge is null");
    if (NULL == pHeight)
      Rf_error("pHeight is null");
    if (NULL == pOrder)
      Rf_error("pOrder is null");

    CDistanceFunction* pDist = NULL;
    TDoubleVector** vec = NULL;
    CClusterNode* pResult = NULL;
    COrdering* pOrdering = NULL;

    int heightLen = rows - 1;


    /* Cluster the data.*/
    CClusterException* pError = NULL;
    bool unexpectedFailure = false;

    // Every failure inside this block is reported by throwing, never by
    // calling Rf_error(): `centroid` and `cluster` live on the stack here
    // and a longjmp out of the block would skip their destructors and
    // the cleanup below.
    try {
      switch(distanceMethod)
        {
        case(CORRELATION):
          pDist = new CCorrelation(rows);
          break;
        case(EUCLIDEAN):
        default:
          pDist = new CEuclidean();
          break;
        }

      if (NULL == pDist)
        {
          __throw_cluster_ex(CLUSTEX_OUTOFMEMORY, "Out of memory! Could not allocate distance object");
        }

      CCentroidLinkage centroid;
      CBasicCluster cluster(pDist, &centroid);

      /* Put the data into a format that the clustering class will understand*/
      vec = VectorizeRMatrix(pData, rows, cols);

      cluster.CheckMemory(cols*sizeof(double), rows);
      pResult = cluster.Cluster((void**)vec, rows);

      if (NULL == pResult)
        __throw_cluster_ex(CLUSTEX_OUTOFMEMORY, "Ran out of memory.");

      /* build the merge structure.
       * From the R help files:
       *
       * "merge is an n-1 by 2 matrix.
       * Row i of merge describes the merging of clusters
       * at step i of the clustering. If an element j in
       * the row is negative, then observation -j was merged
       * at this stage. If j is positive then the merge was
       * with the cluster formed at the (earlier) stage j
       * of the algorithm. Thus negative entries in merge
       * indicate agglomerations of singletons, and positive
       * entries indicate agglomerations of non-singletons."
       *
       * Nodes rows .. 2*rows-2 are visited in order, one per merge row.
       * For each child, index = GetIndex() + 1 (1-based):
       *   index <= rows  -> the child is written as -index
       *   otherwise      -> the child is written as index - rows
       * The left child goes to column 1 (offset0) and the right child
       * to column 2 (offset1 = offset0 + rows - 1).
       */
      int nodeIndex = rows;
      int offset1 = rows - 1;
      for(int offset0 = 0; offset0 < rows-1; offset0++, offset1++, nodeIndex++)
        {
          CClusterNode* pNode = cluster.GetNodeByIndex(nodeIndex);
          int index = pNode->GetLeftChild()->GetIndex()+1;
          pMerge[offset0] = (index <= rows) ? -index : index - rows;
          index = pNode->GetRightChild()->GetIndex()+1;
          pMerge[offset1] = (index <= rows) ? -index : index - rows;
        }

      bool orderingRequested = false;
      switch(orderMethod)
        {
        case(EISEN):
          pOrdering = new CEisenOrdering();
          orderingRequested = true;
          break;
        case(OPTIMAL):
          pOrdering = new COptimalOrdering();
          orderingRequested = true;
          break;
        case(NONE):
        default:
          break;
        }

      if (orderingRequested && NULL == pOrdering)
        {
          __throw_cluster_ex(CLUSTEX_OUTOFMEMORY, "Out of memory! Could not allocate ordering object");
        }

      if (NULL != pOrdering)
        {
          pOrdering->OrderLeaves(pResult, cluster.GetDistanceMatrix());
          delete pOrdering;
          pOrdering = NULL;
        }

      /* construct the "order" structure */
      CLeafOrdering ordering(pOrder);
      pResult->WalkLeaves(&ordering);

      /* Create the "height" structure: the distance between the two
       * children of node rows + i. */
      for(int i = 0; i < heightLen; i++)
        {
          // get node n + i.
          CClusterNode *pNode = cluster.GetNodeByIndex(rows+i);
          pHeight[i] = cluster.GetDistanceFunction()->Distance
            (pNode->GetLeftChild(), pNode->GetRightChild());
        }
    }
    catch(CClusterException* pEx) {
      pError = pEx;
    }
    catch(...) {
      // Any other C++ exception (e.g. a failed allocation) must not
      // propagate out of an extern "C" entry point into R.
      unexpectedFailure = true;
    }

    /* pDist, vec, pResult, pOrdering */
    // vec is still NULL when the failure happened before
    // VectorizeRMatrix() returned.
    if (NULL != vec)
      {
        for(int i = 0; i < rows; i++)
          {
            delete vec[i];
          }
        delete [] vec;
      }

    delete pDist;
    if (NULL != pResult)
      pResult->RecursiveDelete();
    delete pResult;
    delete pOrdering;

    if (NULL != pError)
      {
        // Copy the text into a plain buffer and destroy the std::string
        // before Rf_error(): the longjmp would skip its destructor.
        char message[1024];
        {
          string s(pError->Message());
          if(pError->Id() == CLUSTEX_OUTOFMEMORY)
            s = "Not enough memory. Your dataset may be too large.";
          const string::size_type copied = s.copy(message, sizeof(message) - 1);
          message[copied] = '\0';
        }
        delete pError;

        // The message is data, never a format string.
        Rf_error("%s", message);
      }

    if (unexpectedFailure)
      Rf_error("Unexpected C++ exception during clustering");

    SEXP ret;
    PROTECT(ret = NEW_INTEGER(23)); nProt++;
    UNPROTECT(nProt);

    return ret;
  }
  
}

// Converts a column-major array of doubles (nrow x ncol) into a newly
// allocated array of nrow TDoubleVector pointers, one vector per matrix
// row, each receiving that row's ncol values in column order.
// The array and every vector in it are allocated with new; nothing in
// this function releases them.
TDoubleVector** VectorizeRMatrix(double* rmatrix, int nrow, int ncol)
{
  TDoubleVector** rowVectors = new TDoubleVector*[nrow];
  for (int r = 0; r < nrow; ++r)
    rowVectors[r] = new TDoubleVector();

  // Walk the input once, front to back: consecutive elements belong to
  // consecutive rows of the same column.
  const double* cursor = rmatrix;
  for (int c = 0; c < ncol; ++c)
    {
      for (int r = 0; r < nrow; ++r)
        {
          rowVectors[r]->push_back(*cursor);
          ++cursor;
        }
    }

  return rowVectors;
}

// Recursively computes a depth-style height for pNode and records it.
//
// A leaf has height 1 and writes nothing.  Any other node gets
// 1 + max(height of left child, height of right child), where a missing
// child counts as height 1; that value is stored in
// pHeights[GetIndex() - heightLen - 1] and returned.
// All three arguments are checked with assert only.
double GetHeight(CClusterNode* pNode, double* pHeights, int heightLen)
{
  assert(NULL != pNode);
  assert(NULL != pHeights);
  assert(heightLen > 0);

  if (pNode->IsLeaf())
    return 1;

  // Position of this node's entry in pHeights.
  const int slot = pNode->GetIndex() - heightLen - 1;
  assert(slot >= 0);
  assert(slot < heightLen);

  double leftDepth = 1;
  CClusterNode* const pLeft = pNode->GetLeftChild();
  if (NULL != pLeft)
    leftDepth = GetHeight(pLeft, pHeights, heightLen);

  double rightDepth = 1;
  CClusterNode* const pRight = pNode->GetRightChild();
  if (NULL != pRight)
    rightDepth = GetHeight(pRight, pHeights, heightLen);

  // Ties go to the right-hand value, which is the same number anyway.
  double deeper = rightDepth;
  if (leftDepth > rightDepth)
    deeper = leftDepth;

  pHeights[slot] = deeper + 1;
  return pHeights[slot];
}

// Visitor callback: appends the visited node's index to the order buffer.
// The value stored is exactly what GetIndex() returns (no offset is
// applied here), and the write position then advances by one.
void CLeafOrdering::Go(CClusterNode* pNode)
{
  m_pOrder[m_current] = pNode->GetIndex();
  ++m_current;
}
