/*********************************************************
** Copyright (c) 2005
** University of Washington
** Licensed under the terms set forth by University of
** Washington. If you did not sign such a license, you
** are using this software/code illegally and you do not
** have permission to use, modify, or redistribute
** this or any files in this software package.
**
** File: OptimalOrdering.cpp
**
**********************************************************/
#include "OptimalOrdering.h"
#include "ClusterException.h"
#include "Perf.h"
#include "MsgIds.h"
#include "CostMap.h"

#include <map>
#include <vector>
using namespace std;


/* Reorders the leaves of the hierarchical cluster pCluster in place, using
 * pDistances as the leaf-to-leaf distance matrix.
 *
 * Runs three phases: ComputeCosts (fills the pairwise ordering-cost matrix),
 * GetBestPair (chooses the two outermost leaves) and Backtrack (flips
 * children so the tree realizes that ordering).
 *
 * Throws a heap-allocated CClusterException* (thrown by pointer) when
 * pCluster is NULL (CLUSTEX_NULLARG) or when this object has already been
 * used (CLUSTEX_REUSEDMETHOD): the cost matrix, once created, is what marks
 * the instance as used, so each COptimalOrdering orders exactly one tree.
 */
void COptimalOrdering::OrderLeaves(CClusterNode* pCluster, CMatrix* pDistances)
{
  if (pCluster == NULL)
    throw new CClusterException(CLUSTEX_NULLARG,
                                "argument pCluster may not be null");
  if (m_ppCosts != NULL)
    throw new CClusterException(CLUSTEX_REUSEDMETHOD,
                                "OrderLeaves may not be called more than once");

  CreateInternals(pCluster, pDistances);

  // Phase 1: progress for the cost pass is reported against leafCount^2.
  MARK_TIME("Compute costs...");
  m_iNodeTrackerMax = m_iLeafCount * m_iLeafCount;
  m_iNodeTracker = 0;
  ComputeCosts(pCluster);

  // Phase 2: the cheapest (first leaf, last leaf) pair across the root.
  CClusterNode* pFirstLeaf = NULL;
  CClusterNode* pLastLeaf = NULL;
  MARK_TIME("Get best leaf pair...");
  GetBestPair(pCluster, &pFirstLeaf, &pLastLeaf);

  // Phase 3: realize the ordering; the progress counter restarts.
  MARK_TIME("Backtrack...");
  m_iNodeTracker = 0;
  Backtrack(pCluster, pFirstLeaf, pLastLeaf);
  MARK_TIME("Done ordering.");
}

/* Sets up per-run state for ordering the tree rooted at pRoot: records the
 * leaf/node counts, the root and the distance matrix, and allocates the
 * m_iLeafCount x m_iLeafCount cost matrix m_ppCosts (indexed by leaf index).
 * pRoot must be non-NULL (OrderLeaves checks this before calling).
 */
void COptimalOrdering::CreateInternals(CClusterNode* pRoot, CMatrix* pDistances)
{
  m_iLeafCount = pRoot->GetLeafCount();
  // Node count of a binary tree in which every internal node has two
  // children (assumed here; used only for Backtrack progress reporting).
  m_iNodeCount = 2*m_iLeafCount - 1;
  m_pNode = pRoot;
  m_pDistances = pDistances;

  // "()" value-initializes the row-pointer array to all NULL. If one of the
  // row allocations below throws, DeleteInternals (run by the destructor)
  // then deletes only real rows and NULLs -- never uninitialized pointers.
  m_ppCosts = new double*[m_iLeafCount]();
  for (int i = 0; i < m_iLeafCount; i++)
    m_ppCosts[i] = new double[m_iLeafCount]();  // zeroed rather than garbage
}

void COptimalOrdering::DeleteInternals()
{
  if (NULL != m_ppCosts)
    {
      for(int i = 0 ; i < m_iLeafCount; i++)
	delete m_ppCosts[i];
      delete m_ppCosts;
      m_ppCosts = NULL;
    }
  m_pNode = NULL;
  m_pDistances = NULL;
  m_iLeafCount = 0;
}

void COptimalOrdering::_ComputeCost(int uIndex, int wIndex, CCostMap* pMapM, CCostMap* pMapK, double C)
{
  double umCost, wkCost;
  int mIndex, kIndex;

  double curMin = -1; 
  // for all m in ordered list of costs for u, M
  for(CCostMap::iterator miter = pMapM->begin(); miter != pMapM->end(); miter++)
    {
      umCost = (*miter).first;
      mIndex = (*miter).second;
      
      // k0 = first k in ordered list of costs for w, K
      CCostMap::iterator kiter = pMapK->begin();
      wkCost = (*kiter).first;
      kIndex = (*kiter).second;
      
      // C = min distance between leaf of left child and leaf of right child
      // calculated above
      
      // if cost(u,m) + cost(w, k0) + C >= curMin
      if(curMin >= 0 && umCost + wkCost + C >= curMin)
	{
	  // then cost(u,w) = curMin and break
	  break;
	}
      // for all k in ordered list of costs for w, K
      for (;kiter != pMapK->end();kiter++)
	{
	  wkCost = (*kiter).first;
	  kIndex = (*kiter).second;

	  //double cc = costUM + GetCost(wIndex, kIndex);
	  double cc = umCost + wkCost;
	  // if cost(u,m) + cost(w,k) + C >= curMin
	  if (curMin >= 0 && cc + C >= curMin)
	    {
	      // then break (out of k loop)
	      break;
	    }
	  // candidate = cost(u,m) + cost(w,k) + distance(m,k)
	  double candidate = cc + (kIndex == mIndex ? 0 : m_pDistances->GetValue(mIndex, kIndex));
	  // if curMin > candidate
	  if (curMin < 0 || candidate < curMin)
	    {
	      // then curMin = candidate
	      curMin = candidate;
	    }
	} // k loop
    } // m loop
  SetCost(uIndex, wIndex, curMin);
  SetCost(wIndex, uIndex, curMin);
  DoCallback(m_iNodeTrackerMax, m_iNodeTracker++, OPTORD_COMPUTECOSTS);
}

/* Fills in the ordering cost M(u, w) for every pair of leaves u and w that
 * pCluster splits (u under one child, w under the other), after first doing
 * the same, recursively, for both children.
 *
 * With children L and R:
 *   M(u, w) = min over m, k of  M(u, m) + M(w, k) + distance(m, k),
 * where m lies in the grandchild of L that does not contain u, and k lies in
 * the grandchild of R that does not contain w. A leaf child stands in as its
 * own single grandchild on both sides. Each minimization is delegated to
 * _ComputeCost, which stores M(u, w) and M(w, u).
 *
 * A lone leaf gets M(leaf, leaf) = 0.
 * Throws via __throw_cluster_ex if pCluster is NULL or an internal node is
 * missing a child.
 */
void COptimalOrdering::ComputeCosts(CClusterNode* pCluster)
{
  if (NULL == pCluster)
    __throw_cluster_ex(CLUSTEX_NULLARG, "pCluster is NULL");

  if (pCluster->IsLeaf())
    {
      SetCost(pCluster->GetIndex(), pCluster->GetIndex(), 0);
      DoCallback(m_iNodeTrackerMax, m_iNodeTracker++, OPTORD_COMPUTECOSTS);
      return;
    }

  CClusterNode* pLeft = pCluster->GetLeftChild();
  CClusterNode* pRight = pCluster->GetRightChild();

  // Everything below dereferences both children, so reject a half-built
  // internal node here with a catchable error instead of a NULL dereference.
  if (NULL == pLeft || NULL == pRight)
    __throw_cluster_ex(CLUSTEX_NULLARG, "internal node is missing a child");

  // Post-order: this node's costs are combined from its children's costs.
  ComputeCosts(pLeft);
  ComputeCosts(pRight);

  int uIndex, wIndex, index;

  // C = smallest distance between any leaf under pLeft and any leaf under
  // pRight: a lower bound on distance(m, k) that _ComputeCost prunes with.
  double C = -1;
  for (NodeVector::iterator uiter = pCluster->GetLeftLeaves()->begin();
       uiter != pCluster->GetLeftLeaves()->end();
       uiter++)
    {
      uIndex = (*uiter)->GetIndex();
      for (NodeVector::iterator witer = pCluster->GetRightLeaves()->begin();
           witer != pCluster->GetRightLeaves()->end();
           witer++)
        {
          double dist = m_pDistances->GetValue(uIndex, (*witer)->GetIndex());
          if (C == -1 || dist < C)
            C = dist;
        }
    }

  // Leaf sets of the four grandchildren.
  NodeVector* vLL = pLeft->GetLeftLeaves();
  NodeVector* vLR = pLeft->GetRightLeaves();
  NodeVector* vRR = pRight->GetRightLeaves();
  NodeVector* vRL = pRight->GetLeftLeaves();

  // A leaf child has no grandchildren: use a one-element list holding the
  // leaf itself for both of its grandchild sets. These lists are local
  // objects, so they are released even if a progress callback throws. They
  // must start empty: NodeVector(1) would already contain one NULL element
  // (if NodeVector is a std::vector<CClusterNode*>) ahead of the leaf, and
  // the loops below dereference every element.
  NodeVector leftSingleton;
  NodeVector rightSingleton;
  bool bLeftLeaf = pLeft->IsLeaf();
  bool bRightLeaf = pRight->IsLeaf();
  if (bLeftLeaf)
    {
      leftSingleton.insert(leftSingleton.end(), pLeft);
      vLL = vLR = &leftSingleton;
    }
  if (bRightLeaf)
    {
      rightSingleton.insert(rightSingleton.end(), pRight);
      vRL = vRR = &rightSingleton;
    }

  // u under vLL, so m ranges over vLR.
  for (NodeVector::iterator uiter = vLL->begin(); uiter != vLL->end(); uiter++)
    {
      uIndex = (*uiter)->GetIndex();

      CCostMap mapM;
      for (NodeVector::iterator miter = vLR->begin(); miter != vLR->end(); miter++)
        {
          index = (*miter)->GetIndex();
          mapM.insert(pair<double, int>(GetCost(uIndex, index), index));
        }
      mapM.sort();

      // w under vRL, so k ranges over vRR.
      for (NodeVector::iterator witer = vRL->begin(); witer != vRL->end(); witer++)
        {
          wIndex = (*witer)->GetIndex();
          CCostMap mapK;
          for (NodeVector::iterator kiter = vRR->begin(); kiter != vRR->end(); kiter++)
            {
              index = (*kiter)->GetIndex();
              mapK.insert(pair<double, int>(GetCost(wIndex, index), index));
            }
          mapK.sort();
          _ComputeCost(uIndex, wIndex, &mapM, &mapK, C);
        }

      // w under vRR, so k ranges over vRL. Skipped for a leaf right child:
      // vRL and vRR are then the same single leaf, already handled above.
      if (!bRightLeaf)
        {
          for (NodeVector::iterator witer = vRR->begin(); witer != vRR->end(); witer++)
            {
              wIndex = (*witer)->GetIndex();
              CCostMap mapK;
              for (NodeVector::iterator kiter = vRL->begin(); kiter != vRL->end(); kiter++)
                {
                  index = (*kiter)->GetIndex();
                  mapK.insert(pair<double, int>(GetCost(wIndex, index), index));
                }
              mapK.sort();
              _ComputeCost(uIndex, wIndex, &mapM, &mapK, C);
            }
        }
    }

  // u under vLR, so m ranges over vLL. Skipped for a leaf left child: vLL
  // and vLR are then the same single leaf, already handled above.
  if (!bLeftLeaf)
    {
      for (NodeVector::iterator uiter = vLR->begin(); uiter != vLR->end(); uiter++)
        {
          uIndex = (*uiter)->GetIndex();

          CCostMap mapM;
          for (NodeVector::iterator miter = vLL->begin(); miter != vLL->end(); miter++)
            {
              index = (*miter)->GetIndex();
              mapM.insert(pair<double, int>(GetCost(uIndex, index), index));
            }
          mapM.sort();

          // w under vRL, so k ranges over vRR.
          for (NodeVector::iterator witer = vRL->begin(); witer != vRL->end(); witer++)
            {
              wIndex = (*witer)->GetIndex();
              CCostMap mapK;
              for (NodeVector::iterator kiter = vRR->begin(); kiter != vRR->end(); kiter++)
                {
                  index = (*kiter)->GetIndex();
                  mapK.insert(pair<double, int>(GetCost(wIndex, index), index));
                }
              mapK.sort();
              _ComputeCost(uIndex, wIndex, &mapM, &mapK, C);
            }

          // w under vRR, so k ranges over vRL (skipped for a leaf right child).
          if (!bRightLeaf)
            {
              for (NodeVector::iterator witer = vRR->begin(); witer != vRR->end(); witer++)
                {
                  wIndex = (*witer)->GetIndex();
                  CCostMap mapK;
                  for (NodeVector::iterator kiter = vRL->begin(); kiter != vRL->end(); kiter++)
                    {
                      index = (*kiter)->GetIndex();
                      mapK.insert(pair<double, int>(GetCost(wIndex, index), index));
                    }
                  mapK.sort();
                  _ComputeCost(uIndex, wIndex, &mapM, &mapK, C);
                }
            }
        }
    }
}

/* Picks the outermost leaves of the final ordering. Over every leaf u under
 * pRoot's left side and every leaf w under its right side, it finds the pair
 * with the smallest cost M(u, w) and writes them to *ppU and *ppW.
 *
 * Throws a heap-allocated CClusterException* if any argument is NULL, and
 * throws via __throw_cluster_ex if no pair was found. Reports one
 * OPTORD_GETBESTPAIR progress step per examined pair.
 */
void COptimalOrdering::GetBestPair(CClusterNode* pRoot, CClusterNode** ppU, CClusterNode **ppW)
{
  if (pRoot == NULL || ppU == NULL || ppW == NULL)
    throw new CClusterException(CLUSTEX_NULLARG, "null pointer arguments");

  NodeVector* pLefts = pRoot->GetLeftLeaves();
  NodeVector* pRights = pRoot->GetRightLeaves();

  const int pairCount = pLefts->size() * pRights->size();
  int pairsSeen = 0;

  *ppU = NULL;
  *ppW = NULL;
  double bestCost = -1;  // -1 = nothing chosen yet

  for (NodeVector::iterator itU = pLefts->begin(); itU != pLefts->end(); ++itU)
    {
      const int leftIndex = (*itU)->GetIndex();
      for (NodeVector::iterator itW = pRights->begin(); itW != pRights->end(); ++itW)
        {
          const double pairCost = GetCost(leftIndex, (*itW)->GetIndex());
          if (bestCost < 0 || pairCost < bestCost)
            {
              bestCost = pairCost;
              *ppU = *itU;
              *ppW = *itW;
            }
          DoCallback(pairCount, pairsSeen++, OPTORD_GETBESTPAIR);
        }
    }

  if (*ppU == NULL || *ppW == NULL)
    __throw_cluster_ex(CLUSTEX_NULLARG, "sanity check failed");
}


/* Final phase of OrderLeaves. It flips children throughout the tree rooted
 * at pCluster so that the leaf order starts at leaf u, ends at leaf w, and
 * realizes the cost M(u, w) computed by ComputeCosts.
 *
 * u must lie under one child of pCluster and w under the other; GetBestPair
 * supplies such a pair. At each internal node whose order must run from
 * leaf a to leaf b, it re-derives the inner end leaves: m (last leaf taken
 * from the child holding a) and k (first leaf taken from the child holding
 * b). They are chosen as the pair minimizing
 *   M(a, m) + M(b, k) + distance(m, k),
 * which is the expression _ComputeCost minimized to produce M(a, b). The
 * two must be chosen jointly. Picking m by M(a, m) alone and k by M(b, k)
 * alone ignores distance(m, k) and can produce a suboptimal ordering.
 *
 * Uses an explicit work stack (no recursion, no extra member functions).
 * Reports one OPTORD_BACKTRACK progress step per visited node. Throws via
 * __throw_cluster_ex on a malformed tree.
 */
void COptimalOrdering::Backtrack(CClusterNode* pCluster, CClusterNode* u, CClusterNode* w)
{
  // Pending subtrees, each with the leaf that must come first and the leaf
  // that must come last within it. Parallel vectors are used instead of a
  // local struct, which C++98 does not allow as a template argument.
  std::vector<CClusterNode*> pendingNode;
  std::vector<CClusterNode*> pendingFirst;
  std::vector<CClusterNode*> pendingLast;
  pendingNode.push_back(pCluster);
  pendingFirst.push_back(u);
  pendingLast.push_back(w);

  while (!pendingNode.empty())
    {
      CClusterNode* pNode = pendingNode.back();
      CClusterNode* pFirst = pendingFirst.back();
      CClusterNode* pLast = pendingLast.back();
      pendingNode.pop_back();
      pendingFirst.pop_back();
      pendingLast.pop_back();

      if (NULL != pNode && !pNode->IsLeaf())
        {
          // pFirst must end up under the left child (pLast under the right).
          if (pNode->IsRightLeaf(pFirst))
            pNode->SwapLeaves();

          CClusterNode* pLeft = pNode->GetLeftChild();
          CClusterNode* pRight = pNode->GetRightChild();
          if (NULL == pLeft || NULL == pRight)
            __throw_cluster_ex(CLUSTEX_NULLARG, "internal node is missing a child");

          // Candidates for m: leaves of pLeft on the side away from pFirst,
          // or pLeft itself when it is a leaf. Likewise for k within pRight,
          // on the side away from pLast.
          NodeVector leftSingleton;
          NodeVector rightSingleton;
          NodeVector* pMCands;
          NodeVector* pKCands;
          if (pLeft->IsLeaf())
            {
              leftSingleton.insert(leftSingleton.end(), pLeft);
              pMCands = &leftSingleton;
            }
          else
            pMCands = pLeft->IsRightLeaf(pFirst) ? pLeft->GetLeftLeaves() : pLeft->GetRightLeaves();

          if (pRight->IsLeaf())
            {
              rightSingleton.insert(rightSingleton.end(), pRight);
              pKCands = &rightSingleton;
            }
          else
            pKCands = pRight->IsLeftLeaf(pLast) ? pRight->GetRightLeaves() : pRight->GetLeftLeaves();

          // Exhaustive joint minimization over (m, k).
          const int firstIndex = pFirst->GetIndex();
          const int lastIndex = pLast->GetIndex();
          CClusterNode* pBestM = NULL;
          CClusterNode* pBestK = NULL;
          double bestCost = 0;
          for (NodeVector::iterator mIter = pMCands->begin(); mIter != pMCands->end(); mIter++)
            {
              const int mIndex = (*mIter)->GetIndex();
              const double firstToM = GetCost(firstIndex, mIndex);
              for (NodeVector::iterator kIter = pKCands->begin(); kIter != pKCands->end(); kIter++)
                {
                  const int kIndex = (*kIter)->GetIndex();
                  double cost = firstToM + GetCost(lastIndex, kIndex);
                  if (kIndex != mIndex)
                    cost += m_pDistances->GetValue(mIndex, kIndex);
                  if (NULL == pBestM || cost < bestCost)
                    {
                      bestCost = cost;
                      pBestM = *mIter;
                      pBestK = *kIter;
                    }
                }
            }
          if (NULL == pBestM || NULL == pBestK)
            __throw_cluster_ex(CLUSTEX_NULLARG, "sanity check failed");

          // Push the right subtree first so the left one is processed first.
          pendingNode.push_back(pRight);
          pendingFirst.push_back(pBestK);
          pendingLast.push_back(pLast);
          pendingNode.push_back(pLeft);
          pendingFirst.push_back(pFirst);
          pendingLast.push_back(pBestM);
        }

      DoCallback(m_iNodeCount, m_iNodeTracker++, OPTORD_BACKTRACK);
    }
}

/* Orders the subtree pCluster so that leaf u comes first. If u sits under
 * the right child, the children are swapped. Then the left child is ordered
 * to start at u, and the right child to end at the right-side leaf chosen by
 * GetBestWforU. NULL or leaf nodes only report progress.
 * Reports one OPTORD_BACKTRACK progress step for pCluster itself.
 *
 * NOTE(review): the subtree's far end leaf is chosen from M(u, .) alone.
 * This ignores the distance to whatever sits next to this subtree in the
 * parent's order. The textbook backtrack (Bar-Joseph et al.) recovers the
 * parent's (m, k) pair jointly, so confirm this path is still optimal.
 */
void COptimalOrdering::BacktrackDescendLeft(CClusterNode* pCluster, CClusterNode* u)
{
  const bool bInternal = (pCluster != NULL) && !pCluster->IsLeaf();
  if (bInternal)
    {
      if (pCluster->IsRightLeaf(u))
        pCluster->SwapLeaves();

      CClusterNode* pFarEnd = GetBestWforU(pCluster, u);
      BacktrackDescendLeft(pCluster->GetLeftChild(), u);
      BacktrackDescendRight(pCluster->GetRightChild(), pFarEnd);
    }
  DoCallback(m_iNodeCount, m_iNodeTracker++, OPTORD_BACKTRACK);
}

/* Mirror of BacktrackDescendLeft: orders the subtree pCluster so that leaf w
 * comes last. If w sits under the left child, the children are swapped.
 * Then the right child is ordered to end at w, and the left child to start
 * at the left-side leaf chosen by GetBestUforW. NULL or leaf nodes only
 * report progress. Reports one OPTORD_BACKTRACK progress step for pCluster.
 *
 * NOTE(review): as in BacktrackDescendLeft, the near end leaf is chosen from
 * M(., w) alone, without the distance to the neighbouring subtree. Confirm.
 */
void COptimalOrdering::BacktrackDescendRight(CClusterNode* pCluster, CClusterNode* w)
{
  const bool bInternal = (pCluster != NULL) && !pCluster->IsLeaf();
  if (bInternal)
    {
      if (pCluster->IsLeftLeaf(w))
        pCluster->SwapLeaves();

      CClusterNode* pNearEnd = GetBestUforW(pCluster, w);
      BacktrackDescendLeft(pCluster->GetLeftChild(), pNearEnd);
      BacktrackDescendRight(pCluster->GetRightChild(), w);
    }
  DoCallback(m_iNodeCount, m_iNodeTracker++, OPTORD_BACKTRACK);
}

/* Creates an unused orderer. Nothing is allocated here; the cost matrix is
 * created by CreateInternals when OrderLeaves runs.
 */
COptimalOrdering::COptimalOrdering()
  : m_ppCosts(NULL),        // stays NULL until OrderLeaves; non-NULL marks "used"
    m_iLeafCount(0),
    m_pNode(NULL),
    m_pDistances(NULL),
    m_iNodeCount(0),
    m_iNodeTracker(0),      // running progress counter passed to DoCallback
    m_iNodeTrackerMax(0)
{
}

/* Frees the cost matrix, if one was allocated, by delegating to DeleteInternals. */
COptimalOrdering::~COptimalOrdering()
{
  this->DeleteInternals();
}

/* Returns the stored cost M(u, w), looked up by the two leaves' indices. */
double COptimalOrdering::GetCost(CClusterNode* u, CClusterNode* w)
{
  const int row = u->GetIndex();
  const int col = w->GetIndex();
  return m_ppCosts[row][col];
}

/* Returns the stored cost M(u, w) for leaf indices uIndex and wIndex (no bounds check). */
double COptimalOrdering::GetCost(int uIndex, int wIndex)
{
  const double* pRow = m_ppCosts[uIndex];
  return pRow[wIndex];
}

/* Stores cost as M(u, w), addressed by the two leaves' indices (one direction only). */
void COptimalOrdering::SetCost(CClusterNode* u, CClusterNode* w, double cost)
{
  double* pRow = m_ppCosts[u->GetIndex()];
  pRow[w->GetIndex()] = cost;
}

/* Stores cost at [uIndex][wIndex] of the cost matrix (one direction only, no bounds check). */
void COptimalOrdering::SetCost(int uIndex, int wIndex, double cost)
{
  double* pRow = m_ppCosts[uIndex];
  pRow[wIndex] = cost;
}

/* Returns the leaf w among pCluster's right leaves that minimizes the stored
 * cost M(u, w). Throws via __throw_cluster_ex if no candidate was found.
 */
CClusterNode* COptimalOrdering::GetBestWforU(CClusterNode* pCluster, CClusterNode* u)
{
  const int uIndex = u->GetIndex();
  NodeVector* pRights = pCluster->GetRightLeaves();

  CClusterNode* pChosen = NULL;
  double chosenCost = -1;  // -1 = nothing chosen yet
  for (NodeVector::iterator it = pRights->begin(); it != pRights->end(); ++it)
    {
      const double c = GetCost(uIndex, (*it)->GetIndex());
      if (chosenCost < 0 || c < chosenCost)
        {
          pChosen = *it;
          chosenCost = c;
        }
    }

  if (pChosen == NULL)
    __throw_cluster_ex(CLUSTEX_NULLARG, "minW is NULL");
  return pChosen;
}

/* Returns the leaf u among pCluster's left leaves that minimizes the stored
 * cost M(u, w). Throws via __throw_cluster_ex if no candidate was found.
 */
CClusterNode* COptimalOrdering::GetBestUforW(CClusterNode* pCluster, CClusterNode* w)
{
  const int wIndex = w->GetIndex();
  NodeVector* pLefts = pCluster->GetLeftLeaves();

  CClusterNode* pChosen = NULL;
  double chosenCost = -1;  // -1 = nothing chosen yet
  for (NodeVector::iterator it = pLefts->begin(); it != pLefts->end(); ++it)
    {
      const double c = GetCost((*it)->GetIndex(), wIndex);
      if (chosenCost < 0 || c < chosenCost)
        {
          pChosen = *it;
          chosenCost = c;
        }
    }

  if (pChosen == NULL)
    __throw_cluster_ex(CLUSTEX_NULLARG, "minU is NULL");
  return pChosen;
}


/* Estimates, in bytes, the cost-matrix memory that OrderLeaves allocates for
 * the tree pCluster. That matrix, built by CreateInternals, holds one double
 * per ordered leaf pair plus one row pointer per leaf. It previously always
 * returned 0.
 * Not counted: the caller-owned distance matrix (pDistances is unused) and
 * the short-lived cost maps built inside ComputeCosts.
 * Returns 0 for a NULL tree.
 */
uintmax_t COptimalOrdering::GetMemoryRequirement(CClusterNode* pCluster, CMatrix* pDistances)
{
  (void)pDistances;
  if (NULL == pCluster)
    return 0;

  // Widen before multiplying so large trees cannot overflow int.
  const uintmax_t n = static_cast<uintmax_t>(pCluster->GetLeafCount());
  return n * n * sizeof(double) + n * sizeof(double*);
}
