/*	$Id: fork.cpp 1780 2005-05-17 14:44:00Z jgressma $
 *
 *  Copyright 2005 University of Potsdam, Germany
 * 
 *	This file is part of Platypus. 
 *
 *  Platypus is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  Platypus is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with Platypus; if not, write to the Free Software
 *  Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
 *
 */

#include <mpi.h>
#include <exception>
#include <iostream>
#include <distribution/mpi/mpi_platypus.h>

using namespace std;
  
namespace Platypus
{
  
  // Type tag of this distribution backend ("mpi"); presumably the name used to
  // select it — confirm against DistributionBase users.
  const std::string PlatypusMPI::type_("mpi");
  
  // File-scope timer started in create(). The matching stop/report code in
  // teardown() is commented out, so its reading is currently never used.
  // NOTE(review): this object has external linkage inside namespace Platypus;
  // consider making it static if no other translation unit refers to it.
  Platypus::Time myct;
  
  // Creates an inert instance: every pointer starts null, every counter zero
  // and every flag false. MPI itself is not touched until initialize().
  PlatypusMPI::PlatypusMPI()
    : os_(0),
      printer_(0),
      program_(0),
      serializer_(0),
      stats_(0),
      statistics_(0),
      controller_(0),
      requestedAnswerSets_(0),
      threads_(0),
      messages_(0),
      suppressAnswerSets_(false),
      hasDC_(false),
      shutdown_(false),
      verbose_(false),
      id_(0),
      processes_(0),
      bufferSizePA_(0),
      bufferSizeDC_(0),
      starttime_(0),
      endtime_(0)
  {
  }
  
  // Releases everything allocated by initialize() and setup().
  //
  // setup() constructs controller_ with a reference to *statistics_, so the
  // controller is destroyed first; deleting statistics_ before it would leave
  // the controller holding a dangling reference while its destructor runs.
  // (Deleting a null pointer is a no-op, so ranks that never built a
  // controller are fine.)
  // NOTE(review): this class owns raw pointers; copying an instance would
  // double-delete them. Copy operations should be disabled in the header.
  PlatypusMPI::~PlatypusMPI()
  {
    delete controller_;
    delete statistics_;
    delete serializer_;
    delete stats_;
  }
  
  
  // Factory: starts the file-scope timer and hands back a fresh, not yet
  // initialised MPI backend. The caller owns the returned object.
  DistributionBase* PlatypusMPI::create()
  {
    myct.start();
    DistributionBase* const instance = new PlatypusMPI();
    return instance;
  }
  
  void PlatypusMPI::initialize()
  {
    //MPI related start up
    MPI::Init();//_thread(MPI_THREAD_FUNNELED);//MULTIPLE);

    id_ = MPI::COMM_WORLD.Get_rank();
    processes_ = MPI::COMM_WORLD.Get_size();

    starttime_ = MPI::Wtime();

    //create a serializer for each process
    serializer_ = new MPISerializer(*program_);

    //set up stats collection vector
    stats_ = new std::vector<unsigned long>(NUM_STATS, 0);
    statistics_ = new MPIStatistics();

    //set up buffers for sending and receiving pa's and dc's
    DCBuffer_.resize(serializer_->bytesRequired());
    PABuffer_.resize(serializer_->bytesRequired());
    //std::cout << "Rank " << id_ << " online." << std::endl;
    
  }

  // Initialises MPI on every rank. Only the master (rank 0) continues: it
  // builds the MPIControl controller and runs it. Worker ranks return as soon
  // as initialize() has finished.
  void PlatypusMPI::setup()
  {
    initialize();

    if(id_)
      return;

    assert(program_);
    assert(printer_);

    statistics_->setup(processes_);
    controller_ = new MPIControl(*program_,
                                 *statistics_,
                                 *printer_,
                                 requestedAnswerSets_,
                                 suppressAnswerSets_);
    controller_->setup();
    controller_->start();

    // the master is finished once the controller has run
    shutdown_ = true;
  }
  
  /*
   *
   * BuilderDistributionCallback interface
   *
   */
  
  void PlatypusMPI::processCommandLine(int& argc, char** argv)
  {
 
    ProgramOptions::OptionGroup group;
    group.addOptions() 
      ("verbose,v", ProgramOptions::bool_switch()->defaultValue(false));
    
    ProgramOptions::OptionValues values;
    try
      {
	values.store(ProgramOptions::parseCommandLine(argc, argv, group, true));
      }
    catch(const exception& e)
      {
	MPI::COMM_WORLD.Abort(1);
      }

    bool verbose  = ProgramOptions::option_as<bool>(values, "verbose");
    verbose_ = static_cast<bool>(verbose);
 
  }
  
  void PlatypusMPI::program(const ProgramInterface& program)
  {
    program_ = &program;
  }
  
  // Remembers (without taking ownership) the output stream.
  void PlatypusMPI::output(NopStream& str)
  {
    this->os_ = &str;
  }
  
  void PlatypusMPI::options(const PlatypusOptions& values)
  {
    suppressAnswerSets_ = values.silent();
    requestedAnswerSets_ = values.requestedAnswerSets();
    threads_ = values.threads(); 
  }

  /*
   *
   * CoreDistributionCallback interface
   *
   */
  // The master always answers true. A worker answers true only once shutdown_
  // has been set, which the other callbacks do when they see TERMINATE_TAG.
  //
  // Disabled experiment (investigating why ham_9 takes 2x as long with one
  // worker): additionally polling Iprobe(MASTER, TERMINATE_TAG) here and
  // setting shutdown_ on a hit.
  bool PlatypusMPI::shutdown()
  {
    if(!id_)
      return true;

    return shutdown_;
  }
  
  // Worker-side poll: has the master asked this worker to give up a
  // delegatable choice?
  //
  // Returns true after consuming one DC_NEEDED_TAG message from MASTER.
  // Returns false when no such message is pending, on the master, or when a
  // TERMINATE_TAG is seen; in the last case shutdown_ is set as well.
  bool PlatypusMPI::needDelegatableChoice()
  {
    if(!id_)
      return false;

    if(MPI::COMM_WORLD.Iprobe(MASTER, TERMINATE_TAG))
      {
        shutdown_ = true;
        return false;
      }

    if(!MPI::COMM_WORLD.Iprobe(MASTER, DC_NEEDED_TAG))
      return false;

    // The message has already been probed, so a blocking receive would also
    // complete at once. The non-blocking form is kept so that a termination
    // notice can still interrupt it.
    char dummy = 0;
    MPI::Request localRequest = MPI::COMM_WORLD.Irecv(&dummy,
                                                      1,
                                                      MPI::CHAR,
                                                      MASTER,
                                                      DC_NEEDED_TAG);

    while(!localRequest.Test())
      {
        if(MPI::COMM_WORLD.Iprobe(MASTER, TERMINATE_TAG))
          {
            // Cancel() only marks the request. It must still be completed,
            // otherwise the request leaks and MPI could write into `dummy`
            // after this frame is gone. Wait() on a cancelled request is
            // guaranteed to return locally.
            localRequest.Cancel();
            localRequest.Wait();
            shutdown_ = true;
            return false;
          }
      }

    incMessagesReceived();
    return true;
  }

  // Worker-side: ship `dc` to the master. The return value reports whether
  // sendDC() was invoked. It is false on the master, and false if a
  // termination notice is pending, in which case shutdown_ is set.
  bool PlatypusMPI::delegate(const DelegatableChoice& dc)
  {
    if(!id_)
      return false;

    if(MPI::COMM_WORLD.Iprobe(MASTER, TERMINATE_TAG))
      {
        shutdown_ = true;
        return false;
      }

    sendDC(dc);
    return true;
  }
  
  // Worker-side: request work from the master via receiveDC(), which blocks,
  // and record in hasDC_ whether a choice arrived. If a termination notice is
  // already pending, shutdown_ is set instead. Does nothing on the master.
  void PlatypusMPI::fileDelegatableChoiceRequest()
  {
    if(!id_)
      return;

    if(MPI::COMM_WORLD.Iprobe(MASTER, TERMINATE_TAG))
      {
        shutdown_ = true;
        return;
      }

    hasDC_ = receiveDC();
  }
  
  // On a worker, the only effect is to notice a pending termination notice
  // and set shutdown_. On the master this is a no-op.
  //
  // Disabled experiment (ham_9 timing investigation): if hasDC_ was set, clear
  // it and send dc_ back to the master with sendDC(dc_).
  void PlatypusMPI::cancelDelegatableChoiceRequest()
  {
    if(id_ != 0 && MPI::COMM_WORLD.Iprobe(MASTER, TERMINATE_TAG))
      shutdown_ = true;
  }
  
  // Reports whether a received choice is waiting in dc_ (hasDC_).
  // A worker that finds a pending termination notice sets shutdown_ and
  // reports false whatever hasDC_ says.
  //
  // Disabled experiment (ham_9 timing investigation): calling receiveDC() here
  // and setting hasDC_ on success.
  bool PlatypusMPI::hasDelegatableChoice()
  {
    if(id_ && MPI::COMM_WORLD.Iprobe(MASTER, TERMINATE_TAG))
      {
        shutdown_ = true;
        return false;
      }

    return hasDC_;
  }
  
  // Returns a copy of the stored choice dc_. On a worker, the choice must be
  // present (asserted), and it is marked as consumed.
  DelegatableChoice PlatypusMPI::delegatableChoice()
  {
    if(id_)
      {
        assert(hasDC_);
        hasDC_ = false;
      }

    return dc_;
  }
  
  // Worker-side: serialise the answer set `pa` into PABuffer_ and send it to
  // the master with ANSWER_MESSAGE_TAG. On the master this is a no-op.
  //
  // The send is polled rather than blocking, so that a termination notice can
  // interrupt it. In that case the send is cancelled and shutdown_ is set.
  void PlatypusMPI::answerSet(const PartialAssignment& pa)
  {
    if(!id_)
      return;

    serializer_->serialize(PABuffer_, pa);
    localSlaveRequest_ = MPI::COMM_WORLD.Isend(&PABuffer_[0],
                                               PABuffer_.size(),
                                               MPI::UNSIGNED_LONG,
                                               MASTER,
                                               ANSWER_MESSAGE_TAG);

    while(!localSlaveRequest_.Test())
      {
        if(MPI::COMM_WORLD.Iprobe(MASTER, TERMINATE_TAG))
          {
            // A cancelled request must still be completed. Otherwise the
            // handle leaks, and PABuffer_ may still belong to MPI when it is
            // next reused. Wait() is guaranteed to return locally here.
            localSlaveRequest_.Cancel();
            localSlaveRequest_.Wait();
            shutdown_ = true;
            return;
          }
      }

    incMessagesSent();
  }
  
  
  //helper functions
  // Worker-side: ask the master for work, then block until MASTER sends one
  // of the following:
  //   DC_TAG_FROM_CONTROLLER - a choice. It is deserialised into dc_, hasDC_
  //                            is set and true is returned.
  //   DC_NEEDED_TAG          - a stale request for work. It is consumed,
  //                            counted as dropped, and waiting continues.
  //   TERMINATE_TAG          - shutdown_ is set and false is returned. The
  //                            message is not received here; teardown()
  //                            performs that blocking receive.
  // Any in-flight operation is abandoned when TERMINATE_TAG shows up.
  // Returns false immediately on the master.
  bool PlatypusMPI::receiveDC()
  {
    if(!id_)
      return false;

    // The request payload is this worker's current answer-set count.
    unsigned long answersFound = answerSetsFound();
    localSlaveRequest_ = MPI::COMM_WORLD.Isend(&answersFound,
                                               1,
                                               MPI::UNSIGNED_LONG,
                                               MASTER,
                                               DC_REQUEST_TAG);

    while(!localSlaveRequest_.Test())
      {
        if(MPI::COMM_WORLD.Iprobe(MASTER, TERMINATE_TAG))
          {
            // answersFound lives in this stack frame, so the cancelled send
            // must be completed before returning. Wait() on a cancelled
            // request is guaranteed to return locally.
            localSlaveRequest_.Cancel();
            localSlaveRequest_.Wait();
            shutdown_ = true;
            return false;
          }
      }
    incMessagesSent();

    while(true)
      {
        MPI::COMM_WORLD.Probe(MASTER, MPI::ANY_TAG, localSlaveStatus_);

        // Get_tag() yields an int; compare it as one.
        const int tag = localSlaveStatus_.Get_tag();

        if(tag == DC_TAG_FROM_CONTROLLER)
          {
            size_t count = localSlaveStatus_.Get_count(MPI::UNSIGNED_LONG);
            DCBuffer_.resize(count);
            localSlaveRequest_ = MPI::COMM_WORLD.Irecv(&DCBuffer_[0],
                                                       count,
                                                       MPI::UNSIGNED_LONG,
                                                       MASTER,
                                                       DC_TAG_FROM_CONTROLLER);

            while(!localSlaveRequest_.Test())
              {
                if(MPI::COMM_WORLD.Iprobe(MASTER, TERMINATE_TAG))
                  {
                    // Complete the cancelled receive so that MPI no longer
                    // owns DCBuffer_ when it is next reused.
                    localSlaveRequest_.Cancel();
                    localSlaveRequest_.Wait();
                    shutdown_ = true;
                    return false;
                  }
              }

            serializer_->deserialize(dc_, DCBuffer_);
            hasDC_ = true;
            incMessagesReceived();
            return true;
          }
        else if(tag == DC_NEEDED_TAG)
          {
            // The master wants work from us while we are asking for work
            // ourselves. Consume the request and drop it.
            char dummy = 0;
            localSlaveRequest_ = MPI::COMM_WORLD.Irecv(&dummy,
                                                       1,
                                                       MPI::CHAR,
                                                       MASTER,
                                                       DC_NEEDED_TAG);

            while(!localSlaveRequest_.Test())
              {
                if(MPI::COMM_WORLD.Iprobe(MASTER, TERMINATE_TAG))
                  {
                    // `dummy` is scoped to this branch, so the cancelled
                    // receive must be completed before leaving it.
                    localSlaveRequest_.Cancel();
                    localSlaveRequest_.Wait();
                    shutdown_ = true;
                    return false;
                  }
              }
            incMessagesReceived();
            incDroppedRequests();
          }
        else if(tag == TERMINATE_TAG)
          {
            shutdown_ = true;
            return false;
          }

        // NOTE(review): a message with any other tag is never received, so
        // Probe() keeps matching it and this loop spins. Confirm that the
        // master sends a waiting worker only the three tags handled above.
      }
  }
  
  //helper function
  // Worker-side: serialise `dc` into DCBuffer_ and send it to the master with
  // DC_TAG_FROM_SLAVE.
  // Returns true once the send has completed. Returns false on the master, or
  // if a termination notice interrupted the send; in that case shutdown_ is
  // set.
  bool PlatypusMPI::sendDC(const DelegatableChoice& dc)
  {
    if(!id_)
      return false;

    serializer_->serialize(DCBuffer_, dc);
    localSlaveRequest_ = MPI::COMM_WORLD.Isend(&DCBuffer_[0],
                                               DCBuffer_.size(),
                                               MPI::UNSIGNED_LONG,
                                               MASTER,
                                               DC_TAG_FROM_SLAVE);

    while(!localSlaveRequest_.Test())
      {
        if(MPI::COMM_WORLD.Iprobe(MASTER, TERMINATE_TAG))
          {
            // Complete the cancelled send. Otherwise the request leaks and
            // DCBuffer_ may still be in use by MPI when receiveDC()
            // overwrites it. Wait() is guaranteed to return locally here.
            localSlaveRequest_.Cancel();
            localSlaveRequest_.Wait();
            shutdown_ = true;
            return false;
          }
      }

    incMessagesSent();
    return true;
  }

  // Adds `inc` to this process's EXPANDER_INITS counter.
  void PlatypusMPI::incExpanderInitializations(size_t inc)
  {
    std::vector<unsigned long>& counters = *stats_;
    counters[EXPANDER_INITS] += inc;
  }

  // Adds `inc` to this process's CONFLICTS counter.
  void PlatypusMPI::incConflicts(size_t inc)
  {
    std::vector<unsigned long>& counters = *stats_;
    counters[CONFLICTS] += inc;
  }

  // Adds `inc` to this process's BACKTRACKS counter.
  void PlatypusMPI::incBacktracks(size_t inc)
  {
    std::vector<unsigned long>& counters = *stats_;
    counters[BACKTRACKS] += inc;
  }

  // Adds `inc` to this process's MODELS (answer sets found) counter.
  void PlatypusMPI::incAnswerSets(size_t inc)
  {
    std::vector<unsigned long>& counters = *stats_;
    counters[MODELS] += inc;
  }

  // Adds `inc` to this process's DELEGATIONS counter.
  void PlatypusMPI::incThreadDelegations(size_t inc)
  {
    std::vector<unsigned long>& counters = *stats_;
    counters[DELEGATIONS] += inc;
  }
  
  // Adds `inc` to this process's MESSAGES_SENT counter.
  void PlatypusMPI::incMessagesSent(size_t inc)
  {
    std::vector<unsigned long>& counters = *stats_;
    counters[MESSAGES_SENT] += inc;
  }

  // Adds `inc` to this process's MESSAGES_RECEIVED counter.
  void PlatypusMPI::incMessagesReceived(size_t inc)
  {
    std::vector<unsigned long>& counters = *stats_;
    counters[MESSAGES_RECEIVED] += inc;
  }

  // Adds `inc` to this process's DROPPED_REQUESTS counter.
  void PlatypusMPI::incDroppedRequests(size_t inc)
  {
    std::vector<unsigned long>& counters = *stats_;
    counters[DROPPED_REQUESTS] += inc;
  }

  int PlatypusMPI::id() const
  {
    return id_;
  }  


  /*
   *
   * PlatypusAlgorithmDistributionCallback interface
   *
   */
  
  // Final rendezvous before MPI::Finalize().
  //   Master:  drains outstanding work requests (cleanup()), then records the
  //            end time.
  //   Workers: block on the master's TERMINATE_TAG message, then send their
  //            local stats_ vector back with TERMINATE_CONFIRMATION_TAG.
  // The worker counters are bumped before the send, so the shipped statistics
  // already include that exchange.
  void PlatypusMPI::teardown()
  {
    if(!id_)
      {
        this->cleanup();
        endtime_ = MPI::Wtime();
      }
    else
      {
        char terminateToken = 0;
        MPI::COMM_WORLD.Recv(&terminateToken,
                             1,
                             MPI::CHAR,
                             MASTER,
                             TERMINATE_TAG);

        incMessagesSent();
        incMessagesReceived();

        MPI::COMM_WORLD.Send(&stats_->front(),
                             NUM_STATS,
                             MPI::UNSIGNED_LONG,
                             MASTER,
                             TERMINATE_CONFIRMATION_TAG);
      }

    MPI::Finalize();
  }

  // Only the master (rank 0) prints.
  bool PlatypusMPI::print() const
  {
    return id_ == 0;
  }
  
  // Writes the MPI run summary to `os`: node count, total messages and
  // elapsed MPI time. When verbose_ is set, it also writes a per-rank
  // breakdown taken from statistics_.
  std::ostream& PlatypusMPI::print(std::ostream& os) const
  {
    os << "\tDistributed via mpi.\n";
    os << "\tNodes: " << processes_ << '\n';
    os << "\tTotal messages sent: " << statistics_->messagesSent() << '\n';
    os << "\tTotal messages received: " << statistics_->messagesReceived() << '\n';
    os << "\tMPI time: " << (endtime_ - starttime_) << '\n';
    os << '\n';

    if(!verbose_)
      return os;

    for(unsigned rank = 0; rank < processes_; ++rank)
      {
        if(rank != MASTER)
          {
            os << "Worker " << rank << " statistics:\n";
            os << "\tMessages sent by worker " << rank << ": " << statistics_->indexMessagesSent(rank) << '\n';
            os << "\tMessages received by worker " << rank << ": " << statistics_->indexMessagesReceived(rank) << '\n';
            os << "\tAnswers generated by worker " << rank << ": " << statistics_->indexModels(rank) << '\n';
            if(threads_ > 1)
              os << "\tThread delegations for worker " << rank << ": " << statistics_->indexThreadDelegations(rank) << '\n';
            os << "\tDelegatableChoices received from worker " << rank << ": " << statistics_->indexWorkDelegationsToMaster(rank) << '\n';
            os << "\tDelegatableChoices received by worker " << rank << ": " << statistics_->indexWorkDelegationsFromMaster(rank) << '\n';
            os << "\tTotal requests received from worker " << rank << " to master: " << statistics_->indexWorkRequestsToMaster(rank) << '\n';
            os << "\tTotal requests received by worker " << rank << ": " << statistics_->indexWorkRequestsFromMaster(rank) << '\n';
            os << "\tTotal requests queued for worker " << rank << ": " << statistics_->indexWorkDenials(rank) << '\n';
            os << "\tTotal requests dropped by worker " << rank << ": " << statistics_->indexDroppedRequests(rank) << '\n';
          }
        else
          {
            os << "Master statistics:\n";
            os << "\tMessages sent by master: " << statistics_->indexMessagesSent(MASTER) << '\n';
            os << "\tMessages received by master: " << statistics_->indexMessagesReceived(MASTER) << '\n';
            os << "\tDelegatableChoices received: " << statistics_->workDelegationsToMaster() << '\n';
            os << "\tDelegatableChoices sent: " << statistics_->workDelegationsFromMaster() << '\n';
            os << "\tDelegatableChoice requests received: " << statistics_->workRequestsToMaster() << '\n';
            os << "\tDelegatableChoice requests sent: " << statistics_->workRequestsFromMaster() << '\n';
            os << "\tTotal work requests queued: " << statistics_->workDenials() << '\n';
            os << "\tChoice queue max size: " << statistics_->maxQueueSize() << '\n';
            // (a "Requests queue max size" line via maxFiledWorkers() is disabled)
          }
        os << '\n';
      }

    return os;
  }


  // Master: value from statistics_ (presumably gathered from workers).
  // Worker: its local EXPANDER_INITS counter.
  size_t PlatypusMPI::expanderInitializations() const
  {
    if(!id_)
      return statistics_->expanderInits();
    return (*stats_)[EXPANDER_INITS];
  }

  // Master: value from statistics_. Worker: its local CONFLICTS counter.
  size_t PlatypusMPI::conflicts() const
  {
    if(!id_)
      return statistics_->conflicts();
    return (*stats_)[CONFLICTS];
  }

  // Master: statistics_->models() (an older variant used answers()).
  // Worker: its local MODELS counter.
  size_t PlatypusMPI::answerSetsFound() const
  {
    if(!id_)
      return statistics_->models();
    return (*stats_)[MODELS];
  }

  // Master: value from statistics_. Worker: its local BACKTRACKS counter.
  size_t PlatypusMPI::backtracks() const
  {
    if(!id_)
      return statistics_->backtracks();
    return (*stats_)[BACKTRACKS];
  }

  size_t PlatypusMPI::threads() const
  {
    return threads_;
  }

  // Master: value from statistics_. Worker: its local DELEGATIONS counter.
  size_t PlatypusMPI::threadDelegations() const
  {
    if(!id_)
      return statistics_->threadDelegations();
    return (*stats_)[DELEGATIONS];
  }

  // Master: value from statistics_. Worker: its local MESSAGES_SENT counter.
  size_t PlatypusMPI::messagesSent() const
  {
    if(!id_)
      return statistics_->messagesSent();
    return (*stats_)[MESSAGES_SENT];
  }

  // Master: value from statistics_. Worker: its local MESSAGES_RECEIVED counter.
  size_t PlatypusMPI::messagesReceived() const
  {
    if(!id_)
      return statistics_->messagesReceived();
    return (*stats_)[MESSAGES_RECEIVED];
  }

  // Master: value from statistics_. Worker: its local DROPPED_REQUESTS counter.
  size_t PlatypusMPI::droppedRequests() const
  {
    if(!id_)
      return statistics_->droppedRequests();
    return (*stats_)[DROPPED_REQUESTS];
  }

  // Remembers (without taking ownership) the answer-set printer; setup()
  // hands it to the master's controller.
  void PlatypusMPI::printer(AnswerSetPrinterBase& pr)
  {
    this->printer_ = &pr;
  }
  
  // Intentionally empty: this backend does nothing on terminate().
  void PlatypusMPI::terminate()
  {
    /* no-op */
  }
  

  /*
   *
   * PlatypusMPI specific
   *
   */
  size_t PlatypusMPI::processes() const
  {
    return processes_;
  }
  
  void PlatypusMPI::disableAnswerSetPrinting()
  { 
    suppressAnswerSets_ = true; 
  }

  void PlatypusMPI::enableAnswerSetPrinting() 
  { 
    suppressAnswerSets_ = false; 
  }
  
  size_t PlatypusMPI::bytesRequiredPA() const
  {
    return serializer_->bytesRequired();
  }

  size_t PlatypusMPI::bytesRequiredDC() const
  {
    return serializer_->bytesRequired();
  }
  
  // Master-side drain: receives and discards every DC_REQUEST_TAG message
  // that has already arrived from any rank, counting each as received by
  // MASTER. Only messages present at probe time are consumed.
  void PlatypusMPI::cleanup()
  {
    unsigned long discarded = 0;
    while(MPI::COMM_WORLD.Iprobe(MPI::ANY_SOURCE, DC_REQUEST_TAG, localSlaveStatus_))
      {
        const int requester = localSlaveStatus_.Get_source();
        MPI::COMM_WORLD.Recv(&discarded,
                             1,
                             MPI::UNSIGNED_LONG,
                             requester,
                             DC_REQUEST_TAG);
        statistics_->incMessagesReceived(MASTER);
      }
  }
  

}
