#include "allvars.h"
#include "proto.h"
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>


// Hierarchy/topology description of the calling task.
// NOTE(review): presumably the instance callers pass to numa_init(); that
// function's parameter of the same name shadows this global — confirm.
map_t     Me;           
// One communicator handle per hierarchy level, indexed by the level
// constants (WORLD, myHOST, HOSTS, NUMA, ...) declared elsewhere.
// Filled in by numa_build_mapping().
MPI_Comm  COMM[HLEVELS];

// Printable names of the hierarchy levels.
// NOTE(review): assumed to follow the numeric order of the level constants
// in allvars.h — TODO confirm.
char *LEVEL_NAMES[HLEVELS] = {"NUMA", "ISLAND", "myHOST", "HOSTS", "WORLD"};

// NOTE(review): not referenced anywhere else in this file.
MPI_Aint  win_host_master_size = 0;

// Control window shared through rank 0 of the shared-memory communicator
// (created in numa_init()). On the other ranks, size, displacement unit and
// base pointer are overwritten with those of rank 0's segment.
MPI_Aint    win_ctrl_hostmaster_size; 
MPI_Win     win_ctrl_hostmaster;      
int         win_ctrl_hostmaster_disp; 
void       *win_ctrl_hostmaster_ptr;

// Data window whose memory is allocated only by rank 0 of the shared-memory
// communicator (created in numa_allocate_shared_windows()). The other ranks
// query rank 0's segment into these variables.
MPI_Aint    win_hostmaster_size;
MPI_Win     win_hostmaster;
int         win_hostmaster_disp;
void       *win_hostmaster_ptr; 


// Helpers defined later in this file.
int numa_build_mapping( int, int, MPI_Comm *, map_t *);
int numa_map_hostnames( MPI_Comm *, int, int, map_t *);
int get_cpu_id( void );
int compare_string_int_int( const void *, const void * );


/*
 * Build the task hierarchy on *MYWORLD and allocate the persistent
 * flow-control shared windows.
 *
 *   Rank    : rank of the calling task in *MYWORLD
 *   Size    : number of tasks in *MYWORLD
 *   MYWORLD : communicator the hierarchy is built on
 *   Me      : map filled with the hierarchy and the control windows
 *
 * Collective over *MYWORLD.
 * Returns 0 on success, the non-zero status of numa_build_mapping() when the
 * topology check failed, or -1 if the sibling-window table can not be
 * allocated.
 */
int numa_init( int Rank, int Size, MPI_Comm *MYWORLD, map_t *Me )
{
  // Build up the numa hierarchy. The status is returned only after the
  // collective window allocations below, so that every task takes part in
  // them even if the topology check reported a problem.
  int status = numa_build_mapping( Rank, Size, MYWORLD, Me );

  int SHMEMl = Me->SHMEMl;
  MPI_Info winfo;
  MPI_Info_create(&winfo);
  MPI_Info_set(winfo, "alloc_shared_noncontig", "true");

  // -----------------------------------
  // initialize the flow control windows
  // -----------------------------------

  // one control int per task
  Me->win_ctrl.size = sizeof(int);
  MPI_Win_allocate_shared(Me->win_ctrl.size, 1, winfo, *Me->COMM[SHMEMl],
                          &(Me->win_ctrl.ptr), &(Me->win_ctrl.win));

  // Control int of the shared-memory master. Every task allocates it, but
  // non-master tasks redirect to rank 0's segment below.
  win_ctrl_hostmaster_size = sizeof(int);
  MPI_Win_allocate_shared(win_ctrl_hostmaster_size, 1, winfo, *Me->COMM[SHMEMl],
                          &win_ctrl_hostmaster_ptr, &win_ctrl_hostmaster);

  // the info object is only needed while allocating
  MPI_Info_free(&winfo);

  Me->scwins = (win_t*)malloc(Me->Ntasks[SHMEMl]*sizeof(win_t));
  if ( Me->scwins == NULL )
    return -1;

  // get the addresses of the control windows of all my siblings
  // (myself included) at my shared-memory level
  for ( int t = 0; t < Me->Ntasks[SHMEMl]; t++ )
    MPI_Win_shared_query( Me->win_ctrl.win, t, &(Me->scwins[t].size),
                          &(Me->scwins[t].disp), &(Me->scwins[t].ptr) );

  if ( Me->Rank[SHMEMl] != 0 )
    MPI_Win_shared_query( win_ctrl_hostmaster, 0, &win_ctrl_hostmaster_size,
                          &win_ctrl_hostmaster_disp, &win_ctrl_hostmaster_ptr );

  return status;
}


/*
 * Allocate the data shared windows at the task's shared-memory level.
 *
 *   me        : hierarchy map (me->SHMEMl selects the communicator)
 *   size      : bytes of the per-task window; 0 selects
 *               WIN_HOST_SIZE_DFLT MiB
 *   host_size : bytes of the window owned by the shared-memory master;
 *               0 selects WIN_HOST_MASTER_SIZE_DFLT MiB
 *
 * Collective over *me->COMM[me->SHMEMl].
 * On return, me->swins[t] describes the window of sibling t. On ranks other
 * than 0, win_hostmaster_* describe rank 0's segment.
 * Returns 0, or -1 if the sibling-window table can not be allocated.
 */
int numa_allocate_shared_windows(  map_t *me, MPI_Aint size, MPI_Aint host_size )
{
  int SHMEMl = me->SHMEMl;
  MPI_Info winfo;

  MPI_Info_create(&winfo);
  MPI_Info_set(winfo, "alloc_shared_noncontig", "true");

  // -----------------------------------
  // initialize the data windows
  // -----------------------------------

  // Perform the MiB scaling in MPI_Aint: "DFLT*1024*1024" in int arithmetic
  // overflows for defaults of 2048 MiB or more.
  const MPI_Aint MiB = (MPI_Aint)1024 * 1024;

  win_hostmaster_size = ( host_size == 0 ? WIN_HOST_MASTER_SIZE_DFLT * MiB : host_size );
  MPI_Aint win_host_size = ( size == 0 ? WIN_HOST_SIZE_DFLT * MiB : size );

  me->win.size = win_host_size;
  MPI_Win_allocate_shared(me->win.size, 1, winfo, *me->COMM[SHMEMl], &(me->win.ptr), &(me->win.win));

  // only the shared-memory master contributes memory to this window
  MPI_Aint wsize = ( me->Rank[SHMEMl] == 0 ? win_hostmaster_size : 0 );
  MPI_Win_allocate_shared(wsize, 1, winfo, *me->COMM[SHMEMl], &win_hostmaster_ptr, &win_hostmaster);

  // the info object is only needed while allocating
  MPI_Info_free(&winfo);

  me->swins = (win_t*)malloc(me->Ntasks[SHMEMl]*sizeof(win_t));
  if ( me->swins == NULL )
    return -1;
  me->swins[me->Rank[SHMEMl]] = me->win;

  // get the addresses of all the windows from my siblings
  // at my shared-memory level
  for ( int t = 0; t < me->Ntasks[SHMEMl]; t++ )
    if ( t != me->Rank[SHMEMl] )
      MPI_Win_shared_query( me->win.win, t, &(me->swins[t].size), &(me->swins[t].disp), &(me->swins[t].ptr) );

  if ( me->Rank[SHMEMl] != 0 )
    MPI_Win_shared_query( win_hostmaster, 0, &(win_hostmaster_size), &win_hostmaster_disp, &win_hostmaster_ptr );

  return 0;
}

/*
 * Release the shared windows and the heap tables created by numa_init(),
 * numa_build_mapping(), numa_map_hostnames() and
 * numa_allocate_shared_windows().
 *
 * Collective over each task's shared-memory communicator (MPI_Win_free),
 * so every task must call it.
 * Rank, Size and MYWORLD are unused and kept for interface compatibility.
 * Returns 0.
 */
int numa_shutdown( int Rank, int Size, MPI_Comm *MYWORLD, map_t *me )
{
  (void)Rank;
  (void)Size;
  (void)MYWORLD;

  // free every shared window, in reverse order of creation:
  // first the data windows (numa_allocate_shared_windows), then the
  // flow-control windows (numa_init)
  MPI_Win_free(&win_hostmaster);
  MPI_Win_free(&(me->win.win));
  MPI_Win_free(&win_ctrl_hostmaster);
  MPI_Win_free(&(me->win_ctrl.win));

  // free all the heap tables, clearing the pointers so that a stale use
  // or a second shutdown is harmless
  free(me->Ranks_to_host);
  me->Ranks_to_host = NULL;
  free(me->Ranks_to_myhost);
  me->Ranks_to_myhost = NULL;
  free(me->swins);
  me->swins = NULL;
  free(me->scwins);
  me->scwins = NULL;

  return 0;
}

/*
 * Build the task hierarchy on *MYWORLD.
 *
 *   Rank, Size : rank of the caller in *MYWORLD and size of *MYWORLD
 *   MYWORLD    : base communicator (stored as level WORLD)
 *   me         : map that receives ranks, sizes and communicators per level
 *
 * Levels built:
 *   - myHOST : tasks on the same host
 *   - HOSTS  : one master per host (valid only on host masters)
 *   - NUMA   : MPI_COMM_TYPE_SHARED split of the host
 * me->SHMEMl is set to the highest level whose tasks share memory.
 *
 * Collective over *MYWORLD.
 * Returns 0, or -1 on EVERY task if not all tasks reach SHMEM level myHOST.
 */
int numa_build_mapping( int Rank, int Size, MPI_Comm *MYWORLD, map_t *me )
{
  COMM[WORLD] = *MYWORLD;

  me->Ntasks[WORLD] = Size;
  me->Rank[WORLD]   = Rank;
  me->COMM[WORLD]   = &COMM[WORLD];

  me->mycpu = get_cpu_id();

  // --- find how many hosts we are running on;
  //     that is needed to build the communicator
  //     among the masters of each host
  //
  numa_map_hostnames( &COMM[WORLD], Rank, Size, me );

  me->MAXl = ( me->Nhosts > 1 ? HOSTS : myHOST );

  // --- create the communicator for each host
  //     (the Rank/Size parameters are left untouched: they keep
  //      referring to WORLD for the rest of the function)
  //
  MPI_Comm_split( COMM[WORLD], me->myhost, me->Rank[WORLD], &COMM[myHOST]);
  me->COMM[myHOST] = &COMM[myHOST];
  MPI_Comm_size( COMM[myHOST], &(me->Ntasks[myHOST]) );
  MPI_Comm_rank( COMM[myHOST], &(me->Rank[myHOST]) );

  // with the following gathering we build-up the mapping Ranks_to_host, so that
  // we know which host each mpi rank (meaning the original rank) belongs to
  //
  MPI_Allgather( &me->myhost, sizeof(me->myhost), MPI_BYTE,
                 me->Ranks_to_host, sizeof(me->myhost), MPI_BYTE, COMM[WORLD] );

  // Ranks_to_myhost[i] = WORLD rank of the i-th task of my host.
  // An allocation failure can not be reported by returning here (the other
  // tasks would block in the next collective), hence the abort.
  me->Ranks_to_myhost = (int*)malloc(me->Ntasks[myHOST]*sizeof(int));
  if ( me->Ranks_to_myhost == NULL )
    {
      fprintf(stderr, "numa_build_mapping: out of memory on rank %d\n", me->Rank[WORLD]);
      MPI_Abort( COMM[WORLD], 1 );
    }
  MPI_Allgather( &me->Rank[WORLD], 1, MPI_INT,
                 me->Ranks_to_myhost, 1, MPI_INT, COMM[myHOST] );

  // --- create the communicator for the
  //     masters of each host
  //
  int Im_host_master = ( me->Rank[myHOST] == 0 );
  MPI_Comm_split( COMM[WORLD], Im_host_master, me->Rank[WORLD], &COMM[HOSTS]);
  //
  // NOTE: by default, the Rank 0 in WORLD is also Rank 0 in HOSTS
  //
  if ( Im_host_master )
    {
      me->COMM[HOSTS]   = &COMM[HOSTS];
      me->Ntasks[HOSTS] = me->Nhosts;
      MPI_Comm_rank( COMM[HOSTS], &(me->Rank[HOSTS]) );
    }
  else
    {
      me->COMM[HOSTS]   = NULL;
      me->Ntasks[HOSTS] = 0;
      me->Rank[HOSTS]   = -1;
    }

  // --- create the communicator for the
  //     numa node
  //
  MPI_Comm_split_type( COMM[myHOST], MPI_COMM_TYPE_SHARED, me->Rank[myHOST], MPI_INFO_NULL, &COMM[NUMA]);
  me->COMM[NUMA] = &COMM[NUMA];
  MPI_Comm_size( COMM[NUMA], &(me->Ntasks[NUMA]) );
  MPI_Comm_rank( COMM[NUMA], &(me->Rank[NUMA]) );

  // check whether NUMA == myHOST and determine
  // the maximum level of shared memory in the
  // topology
  //
  if ( me->Ntasks[NUMA] == me->Ntasks[myHOST] )
    {
      // collapse levels from NUMA to myHOST
      //
      me->Ntasks[ISLAND] = me->Ntasks[NUMA];  // equating to NUMA as we know the rank better via MPI_SHARED
      me->Rank[ISLAND]   = me->Rank[NUMA];
      me->COMM[ISLAND]   = me->COMM[NUMA];

      me->Rank[myHOST]   = me->Rank[NUMA];
      me->COMM[myHOST]   = me->COMM[NUMA];
      me->SHMEMl         = myHOST;
    }
  else
    {
      // this case is not supported at the moment
      printf(">>> It seems that rank %d belongs to a node for which "
             "    the node topology does not coincide \n", me->Rank[WORLD] );
      me->SHMEMl = NUMA;
    }

  // Every task must end up with SHMEM level myHOST. The per-task verdicts are
  // combined with MPI_MIN so that a single failing task fails the check
  // (MPI_MAX would let mixed levels pass unnoticed).
  int globalmax_SHMEM_level;
  MPI_Allreduce( &(me->SHMEMl), &globalmax_SHMEM_level, 1, MPI_INT, MPI_MAX, *MYWORLD );

  int check_SHMEM_level = ( (me->SHMEMl == myHOST) && (globalmax_SHMEM_level == me->SHMEMl) );
  int globalcheck_SHMEM_level;
  MPI_Allreduce( &check_SHMEM_level, &globalcheck_SHMEM_level, 1, MPI_INT, MPI_MIN, *MYWORLD );

  if ( globalcheck_SHMEM_level < 1 )
    {
      if ( me->Rank[WORLD] == 0 )
        printf("There was an error in determining the topology hierarchy, "
               "SHMEM level is different for different MPI tasks\n");
      // the verdict is global, so every task reports the failure
      return -1;
    }

  return 0;
}


/*
 * qsort() comparator for the fixed-stride hostname table built in
 * numa_map_hostnames(): every record starts with a NUL-terminated string.
 */
static int numa_compare_hostnames( const void *A, const void *B )
{
  return strcmp( (const char *)A, (const char *)B );
}


/*
 * Count the distinct hosts the tasks of *MY_WORLD run on, and find the index
 * of the caller's host. Hosts are numbered 0..Nhosts-1 in lexicographic order
 * of their hostname.
 *
 * Side effects on *me:
 *   - Ranks_to_host : allocated with Ntasks ints (filled later by the caller)
 *   - Nhosts        : number of distinct hosts
 *   - myhost        : index of the caller's host
 *
 * Collective over *MY_WORLD. Aborts MPI_Abort()s on allocation failure.
 * Returns me->Nhosts.
 */
int numa_map_hostnames( MPI_Comm *MY_WORLD,   // the communicator to refer to
                        int Rank,              // the initial rank of the calling process in MYWORLD
                        int Ntasks,            // the number of tasks in MY_WORLD
                        map_t *me)             // address of the info structure for the calling task
{
  // Host membership is decided by hostname alone.
  // The rank is kept for interface compatibility.
  (void)Rank;

  // --------------------------------------------------
  // --- init some global vars
  me->Ranks_to_host = (int*)malloc(Ntasks*sizeof(int));
  me->Nhosts = 0;
  me->myhost = -1;

  // --------------------------------------------------
  // --- find how many hosts we are using

  // Zero-filled so the bytes past the terminator are defined when the buffer
  // is sent. The explicit terminator guards against a silently truncated name.
  char myhostname[HOST_NAME_MAX+1] = {0};
  gethostname( myhostname, sizeof myhostname );
  myhostname[HOST_NAME_MAX] = '\0';

  // determine how much space to book for hostnames (terminator included)
  int myhostlen  = (int)strlen(myhostname) + 1;
  int maxhostlen = 0;
  MPI_Allreduce( &myhostlen, &maxhostlen, 1, MPI_INT, MPI_MAX, *MY_WORLD );

  // Collect all hostnames into a table of Ntasks records of maxhostlen bytes.
  // maxhostlen <= HOST_NAME_MAX+1, so myhostname holds a full record.
  char *allnames = (char*)calloc( Ntasks, maxhostlen );
  if ( allnames == NULL || me->Ranks_to_host == NULL )
    {
      fprintf(stderr, "numa_map_hostnames: out of memory\n");
      MPI_Abort( *MY_WORLD, 1 );
    }

  MPI_Allgather( myhostname, maxhostlen, MPI_CHAR,
                 allnames,   maxhostlen, MPI_CHAR, *MY_WORLD );

  // Sort by hostname, so that equal names become adjacent.
  qsort( allnames, Ntasks, maxhostlen, numa_compare_hostnames );

  // --- count how many diverse hosts we have; my host index is the index
  //     of the group of equal names that matches my own hostname
  const char *prev = allnames;
  for ( int R = 0; R < Ntasks; R++ )
    {
      const char *name = allnames + (size_t)R * (size_t)maxhostlen;

      if ( strcmp(name, prev) != 0 )
        {
          me->Nhosts++;
          prev = name;
        }

      if ( strcmp(name, myhostname) == 0 )   // it's my host
        me->myhost = me->Nhosts;
    }
  me->Nhosts++;

  free( allnames );

  return me->Nhosts;
}



int compare_string_int_int( const void *A, const void *B )
// qsort-style comparator for fixed-layout records that start with a
// NUL-terminated string, followed by two ints used as tie-breakers.
//
// Configuration call:
//   int len = ...;
//   compare_string_int_int( NULL, &len );
// After it, the first tie-breaking int is read at byte offset len+1 from the
// start of each record, and the second one right after it. The call returns 0.
//
// Ordering:
//   - by strcmp on the string;
//   - then by the first int, then by the second int.
// Configuring with len == -1 (offset 0) disables the tie-breakers, so that
// only the strings are compared.
//
// The ints are read with memcpy, so misaligned offsets are safe. They are
// compared without subtraction, so extreme values can not overflow.
// The configuration lives in a static variable: sorts with different
// layouts must not run concurrently.
{
  static int str_len = 0;
  if ( A == NULL )
    {
      str_len = *(const int *)B + 1;
      return 0;
    }

  // We do not use strncmp: str_len == 0 (i.e. using this function without
  // configuring it) must still give a sorting on the strings only.
  int order = strcmp( (const char *)A, (const char *)B );

  if ( str_len && !order )
    {
      const char *pa = (const char *)A + str_len;
      const char *pb = (const char *)B + str_len;

      for ( int k = 0; k < 2 && !order; k++ )
        {
          int ia, ib;
          memcpy( &ia, pa + k * sizeof(int), sizeof ia );
          memcpy( &ib, pb + k * sizeof(int), sizeof ib );
          order = (ia > ib) - (ia < ib);
        }
    }

  return order;
}


#define CPU_ID_ENTRY_IN_PROCSTAT 39

int read_proc__self_stat( int field, int *ret_val );

/*
 * Return the id of the CPU the calling task is running on, or -1 when it
 * can not be determined.
 *
 * The source is chosen at compile time:
 *   - sched_getcpu(), when _GNU_SOURCE is defined;
 *   - otherwise the getcpu system call, when SYS_getcpu is known;
 *   - otherwise field CPU_ID_ENTRY_IN_PROCSTAT of /proc/self/stat.
 */
int get_cpu_id( void )
{
#if defined(_GNU_SOURCE)
  return sched_getcpu();
#elif defined(SYS_getcpu)
  int cpu;
  return ( syscall( SYS_getcpu, &cpu, NULL, NULL ) == -1 ) ? -1 : cpu;
#else
  int cpu;
  return ( read_proc__self_stat( CPU_ID_ENTRY_IN_PROCSTAT, &cpu ) == -1 ) ? -1 : cpu;
#endif
}



int read_proc__self_stat( int field, int *ret_val )
/*
  Read the numeric value of field number `field` of /proc/self/stat into
  *ret_val. Fields are numbered from 1, as in proc(5).

  Some fields of interest (see proc(5)):

  pid         : 1
  ppid        : 4
  utime       : 14
  stime       : 15
  cutime      : 16
  num_threads : 20
  rss         : 24
  processor   : 39   (CPU the task last ran on)

  Field 2 (comm) is the executable name in parentheses and may contain
  spaces or ')'. Fields after it are located from the last ')' in the line.
  Non-numeric fields (comm, state) yield 0.

  Returns 0 on success. Returns -1 if the file can not be read, if
  field < 1, or if the line has fewer than `field` fields.
  *ret_val is set to 0 first, so it is 0 on every failure.
 */
{
  *ret_val = 0;

  if ( field < 1 )
    return -1;

  FILE *file = fopen( "/proc/self/stat", "r" );
  if ( file == NULL )
    return -1;

  // the stat line is a short comm plus a few dozen numbers: a fixed buffer
  // is ample
  char  line[4096];
  char *got = fgets( line, sizeof line, file );
  fclose( file );

  if ( got == NULL )
    return -1;

  // p walks the line; cur is the number of fields already passed over
  const char *p   = line;
  int         cur = 0;

  if ( field > 2 )
    {
      // skip pid and "(comm)" in one go: comm may contain blanks
      p = strrchr( line, ')' );
      if ( p == NULL )
        return -1;
      p++;
      cur = 2;
    }

  for ( ;; )
    {
      while ( *p == ' ' )
        p++;
      if ( *p == '\0' || *p == '\n' )
        return -1;                 // fewer fields than requested
      if ( ++cur == field )
        break;
      while ( *p != ' ' && *p != '\0' && *p != '\n' )
        p++;
    }

  *ret_val = (int)strtol( p, NULL, 10 );

  return 0;
}