/*******************************************************************************
* Copyright (C) 2014 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

//@HEADER
// ***************************************************
//
// HPCG: High Performance Conjugate Gradient Benchmark
//
// Contact:
// Michael A. Heroux ( maherou@sandia.gov)
// Jack Dongarra     (dongarra@eecs.utk.edu)
// Piotr Luszczek    (luszczek@eecs.utk.edu)
//
// ***************************************************
//@HEADER

/*!
 @file ExchangeHalo.cpp

 HPCG routine
 */

// Compile this routine only if running with MPI
#ifndef HPCG_NO_MPI
#include <mpi.h>
#include "Geometry.hpp"
#include "ExchangeHalo.hpp"
#include <cstdlib>

#include "VeryBasicProfiler.hpp"

// Optional profiling hooks used by the SYCL ExchangeHalo overload in this file.
//
// Usage notes:
//  - The expansions reference a variable named `optData` that must be in scope
//    at the point of use; they call begin(n)/end(n) on its `profiler` member.
//  - Every expansion already ends with ';', and END_PROFILE_WAIT expands to two
//    statements, so never use these macros as the sole body of an unbraced
//    if/else/loop.
//  - Without BASIC_PROFILING all three expand to nothing; in particular the
//    event.wait() in END_PROFILE_WAIT disappears, so no synchronization happens.
#ifdef BASIC_PROFILING
// Calls optData->profiler->begin(n) for the region labelled n.
#define BEGIN_PROFILE(n) optData->profiler->begin(n);
// Calls optData->profiler->end(n) for the region labelled n.
#define END_PROFILE(n) optData->profiler->end(n);
// Blocks on `event` first, then calls optData->profiler->end(n).
#define END_PROFILE_WAIT(n, event) event.wait(); optData->profiler->end(n);
#else
// Profiling disabled: the hooks compile away completely.
#define BEGIN_PROFILE(n)
#define END_PROFILE(n)
#define END_PROFILE_WAIT(n, event)
#endif

/*!
  Communicates data that is at the border of the part of the domain assigned to this processor.

  @param[in]    A The known system matrix
  @param[inout] x On entry: the local vector entries followed by entries to be communicated; on exit: the vector with non-local entries updated by other processors
 */
void ExchangeHalo(const SparseMatrix & A, Vector & x) {
  // Extract Matrix pieces
  if ( A.geom->size > 1 )
  {
#ifdef HPCG_LOCAL_LONG_LONG
  // using MPI_ISend and MPI_IRecv since MPI_Alltoallv cannot handle long long arguments

  local_int_t localNumberOfRows = A.localNumberOfRows;
  int num_neighbors = A.numberOfSendNeighbors;
  local_int_t * receiveLength = A.receiveLength;
  local_int_t * sendLength = A.sendLength;
  int * neighbors = A.neighbors;
  double * sendBuffer = A.sendBuffer_h;
  local_int_t totalToBeSent = A.totalToBeSent;
  local_int_t * elementsToSend = A.elementsToSend_h;

  double * const xv = x.values;

  int MPI_MY_TAG = 99;
  MPI_Request * request = new MPI_Request[num_neighbors];

  // Externals are at end of locals
  double * x_external = (double *) xv + localNumberOfRows;

  // Post receives first
  for (int i = 0; i < num_neighbors; i++) {
    int n_recv = receiveLength[i];
    MPI_Irecv(x_external, n_recv, MPI_DOUBLE, neighbors[i], MPI_MY_TAG, MPI_COMM_WORLD, request+i);
    x_external += n_recv;
  }

  // Fill up send buffer
  for (local_int_t i=0; i<totalToBeSent; i++) sendBuffer[i] = xv[elementsToSend[i]];

  // Send to each neighbor
  for (int i = 0; i < num_neighbors; i++) {
    int n_send = sendLength[i];
    MPI_Send(sendBuffer, n_send, MPI_DOUBLE, neighbors[i], MPI_MY_TAG, MPI_COMM_WORLD);
    sendBuffer += n_send;
  }

  // Complete the reads issued above
  MPI_Status status;
  for (int i = 0; i < num_neighbors; i++) {
    if ( MPI_Wait(request+i, &status) ) {
      exit(-1); // TODO: have better error exit
    }
  }

  delete [] request;
#else
      local_int_t localNumberOfRows = A.localNumberOfRows;
      double * sendBuffer = A.sendBuffer_h;
      local_int_t totalToBeSent = A.totalToBeSent;
      local_int_t * elementsToSend = A.elementsToSend_h;

      double * const xv = x.values;
      double * x_external = (double *) xv + localNumberOfRows;

      #pragma ivdep
      for (local_int_t i=0; i<totalToBeSent; i++) sendBuffer[i] = xv[elementsToSend[i]];

      MPI_Alltoallv( sendBuffer, A.scounts, A.sdispls, MPI_DOUBLE, x_external, A.rcounts, A.rdispls, MPI_DOUBLE, MPI_COMM_WORLD);
#endif
  }
  return;
}


/*!
  Halo exchange that is ordered against work submitted to a SYCL queue.

  When A.geom->size is not greater than 1 (or HPCG_TEST_NO_HALO_EXCHANGE is defined) no data
  is moved and only a barrier on deps is submitted.

  Build-time variants of the actual exchange:
   - HPCG_LOCAL_LONG_LONG: blocking host-side exchange with MPI_Irecv/MPI_Send. The function
     waits for deps before touching x and returns a barrier event on deps.
   - HPCG_USE_MPI_OFFLOAD: a kernel packs A.sendBuffer, then a host task calls MPI_Alltoallv
     directly on A.sendBuffer and on the halo part of x.values.
   - otherwise: a kernel packs A.sendBuffer, the data is copied to A.sendBuffer_h, a host task
     calls MPI_Alltoallv on host buffers, and the result is copied back into x when
     optData->halo_host_vector is set.

  @param[in]    A          The known system matrix
  @param[inout] x          On entry: the local vector entries followed by entries to be communicated; on exit (once the returned event completes): the vector with non-local entries updated by other processors
  @param[in]    main_queue Queue that receives the kernels, copies, host tasks and barriers
  @param[in]    deps       Events that must complete before x is read

  @return Event of the last operation submitted by this call
 */
sycl::event ExchangeHalo(const SparseMatrix &A, Vector &x, sycl::queue &main_queue,
                         const std::vector<sycl::event> &deps) {
  sycl::event last_ev;
  if ( A.geom->size > 1 )
  {
#ifdef HPCG_TEST_NO_HALO_EXCHANGE
    // Test mode: skip the communication but keep the dependency chain intact
    last_ev = main_queue.ext_oneapi_submit_barrier(deps);
    return last_ev;
#endif

#ifdef HPCG_LOCAL_LONG_LONG
    // using MPI_Send and MPI_Irecv since MPI_Alltoallv cannot handle long long arguments

    // This path reads and writes x directly from the host, so everything that
    // produces x has to be finished before the exchange starts.
    sycl::event::wait(deps);

    const local_int_t localNumberOfRows = A.localNumberOfRows;
    const int num_neighbors = A.numberOfSendNeighbors;
    const local_int_t * const receiveLength = A.receiveLength;
    const local_int_t * const sendLength = A.sendLength;
    const int * const neighbors = A.neighbors;
    double * sendBuffer_h = A.sendBuffer_h;
    const local_int_t totalToBeSent = A.totalToBeSent;
    // NOTE(review): x.values and elementsToSend_d are dereferenced on the host below;
    // this assumes both allocations are host-accessible -- confirm against the allocator.
    const local_int_t * const elementsToSend_d = A.elementsToSend_d;

    double * const xv = x.values;

    const int MPI_MY_TAG = 99;
    // One request per neighbor; the vector releases its storage on every exit path.
    std::vector<MPI_Request> request(num_neighbors);

    // Externals are at end of locals
    double * x_external = xv + localNumberOfRows;

    // Post receives first so that the blocking sends below always find a matching receive
    for (int i = 0; i < num_neighbors; i++) {
      // MPI counts are int; the narrowing from local_int_t is intentional and made explicit
      const int n_recv = static_cast<int>(receiveLength[i]);
      MPI_Irecv(x_external, n_recv, MPI_DOUBLE, neighbors[i], MPI_MY_TAG, MPI_COMM_WORLD, &request[i]);
      x_external += n_recv;
    }

    // Fill up send buffer
    for (local_int_t i = 0; i < totalToBeSent; i++) sendBuffer_h[i] = xv[elementsToSend_d[i]];

    // Send to each neighbor; each neighbor's slice follows the previous one in the buffer
    for (int i = 0; i < num_neighbors; i++) {
      const int n_send = static_cast<int>(sendLength[i]);
      MPI_Send(sendBuffer_h, n_send, MPI_DOUBLE, neighbors[i], MPI_MY_TAG, MPI_COMM_WORLD);
      sendBuffer_h += n_send;
    }

    // Complete the receives issued above
    if ( MPI_Waitall(num_neighbors, request.data(), MPI_STATUSES_IGNORE) != MPI_SUCCESS ) {
      // Abort the whole job: a plain exit() would end only this rank and could leave
      // the other ranks blocked inside their own MPI calls.
      MPI_Abort(MPI_COMM_WORLD, -1);
    }
    // The exchange is complete here; control falls through to the barrier at the end.
#else
      local_int_t localNumberOfRows = A.localNumberOfRows;
      double * sendBuffer_d = A.sendBuffer; // device
      double * sendBuffer_h = A.sendBuffer_h;
      const local_int_t totalToBeSent = A.totalToBeSent;
      local_int_t * elementsToSend_d = A.elementsToSend_d;

      double * const xv_d = x.values;

      // Externals are at end of locals
      double * x_external_d = static_cast<double *>(xv_d);
      x_external_d += localNumberOfRows;
      // The variable must be called optData: the profiling macros expand to optData->profiler
      struct optData *optData = (struct optData*)A.optimizationData;

#ifdef HPCG_USE_MPI_OFFLOAD
      (void)sendBuffer_h; // to avoid compiler warning
      (void)optData; // to avoid compiler warning
      // Stage 1: pack the boundary values into the send buffer once deps have completed
      BEGIN_PROFILE("ExchangeHalo:fill_buffer");
      auto fill_ev = main_queue.submit([&](sycl::handler &cgh) {
        cgh.depends_on(deps);
        auto kernel = [=] (sycl::item<1> item) {
          local_int_t row = item.get_id(0);
          sendBuffer_d[row] = xv_d[elementsToSend_d[row]]; // elementsToSend are permuted in OptimizeProblem already
        };
        cgh.parallel_for<class ExchangeHaloClass>(sycl::range<1>(totalToBeSent), kernel);
      });
      END_PROFILE_WAIT("ExchangeHalo:fill_buffer", fill_ev);

      // Stage 2: host task that hands sendBuffer_d and x_external_d straight to MPI
      BEGIN_PROFILE("ExchangeHalo:MPI_Alltoallv");
      last_ev = main_queue.submit([&](sycl::handler &cgh) {
          // Copy the pointers into locals so the host task captures them by value
          local_int_t *scounts_h = A.scounts;
          local_int_t *sdispls_h = A.sdispls;
          local_int_t *rcounts_h = A.rcounts;
          local_int_t *rdispls_h = A.rdispls;
          cgh.depends_on(fill_ev);
          auto kernel = [=]() {
              MPI_Alltoallv( sendBuffer_d, scounts_h, sdispls_h, MPI_DOUBLE, x_external_d,
                             rcounts_h, rdispls_h, MPI_DOUBLE, MPI_COMM_WORLD);
          };
          cgh.host_task(kernel);
      });
      END_PROFILE_WAIT("ExchangeHalo:MPI_Alltoallv", last_ev);

#else // not HPCG_USE_MPI_OFFLOAD

      double * x_external_h = x_external_d;
      if (optData->halo_host_vector != nullptr) {
          // x_external is generally device memory, which is fine on 1 process,
          // but we need host memory for MPI communication
          x_external_h = optData->halo_host_vector;
      }

      // Stage 1: pack on the queue, then copy the packed values into the host send buffer
      BEGIN_PROFILE("ExchangeHalo:fill_buffer");
      auto ev = main_queue.submit([&](sycl::handler &cgh) {
        cgh.depends_on(deps);
        auto kernel = [=] (sycl::item<1> item) {
          local_int_t row = item.get_id(0);
          sendBuffer_d[row] = xv_d[elementsToSend_d[row]]; // elementsToSend are permuted in OptimizeProblem already
        };
        cgh.parallel_for<class ExchangeHaloClass>(sycl::range<1>(totalToBeSent), kernel);
      });
      auto copy_ev = main_queue.memcpy(sendBuffer_h, sendBuffer_d, sizeof(double) * (totalToBeSent), ev);
      END_PROFILE_WAIT("ExchangeHalo:fill_buffer", copy_ev);

      // Stage 2: host task performing the exchange on host buffers
      BEGIN_PROFILE("ExchangeHalo:MPI_Alltoallv");
      auto mpi_ev = main_queue.submit([&](sycl::handler &cgh) {
          // Copy the pointers into locals so the host task captures them by value
          local_int_t *scounts_h = A.scounts;
          local_int_t *sdispls_h = A.sdispls;
          local_int_t *rcounts_h = A.rcounts;
          local_int_t *rdispls_h = A.rdispls;
          cgh.depends_on(copy_ev);
          auto kernel = [=]() {
              MPI_Alltoallv( sendBuffer_h, scounts_h, sdispls_h, MPI_DOUBLE, x_external_h,
                             rcounts_h, rdispls_h, MPI_DOUBLE, MPI_COMM_WORLD);
          };
          cgh.host_task(kernel);
      });
      END_PROFILE_WAIT("ExchangeHalo:MPI_Alltoallv", mpi_ev);

      // Stage 3: if a separate host halo buffer was used, copy the received values
      // (localNumberOfColumns - localNumberOfRows entries) into the tail of x
      BEGIN_PROFILE("ExchangeHalo:memcpy");
      if (optData->halo_host_vector != nullptr) {
          last_ev = main_queue.memcpy(x_external_d, x_external_h,
                                      sizeof(double) * (A.localNumberOfColumns - localNumberOfRows), mpi_ev);
      }
      else {
        last_ev = mpi_ev;
      }
      END_PROFILE_WAIT("ExchangeHalo:memcpy", last_ev);
#endif // HPCG_USE_MPI_OFFLOAD

      return last_ev;
#endif
  }

  // Single process (or synchronous exchange already finished): only forward the dependencies
  last_ev = main_queue.ext_oneapi_submit_barrier(deps);
  return last_ev;
}
#endif
// ifndef HPCG_NO_MPI
