#include <mpi.h>
#include <omp.h>
#include <stdio.h>
#include "../common/common.h"
#include "../tree/tree.h"
#include "../adp/adp.h"
#include <argp.h>


/*
 * Output file locations, selected at compile time by a platform flag
 * (AMONRA, LEONARDO or LUMI). When no platform flag provides
 * OUT_CLUSTER_ASSIGN, files are written relative to the working directory.
 *
 * NOTE(review): the blocks are independent #ifdef's, so defining more than
 * one platform flag redefines the macros.
 */
#ifdef AMONRA
    #pragma message "Hi, you are on amonra"
    #define OUT_CLUSTER_ASSIGN "/beegfs/ftomba/phd/results/final_assignment.npy"
    /* NOTE(review): OUT_HALO_FLAGS is defined only in this configuration;
     * the other configurations and the fallback below do not provide it. */
    #define OUT_HALO_FLAGS     "/beegfs/ftomba/phd/results/halo_flags.npy"
    #define OUT_DATA           "/beegfs/ftomba/phd/results/ordered_data.npy"
#endif

#ifdef LEONARDO
    #define OUT_CLUSTER_ASSIGN "/leonardo_scratch/large/userexternal/ftomba00/out_dadp/final_assignment.npy"
    #define OUT_DATA           "/leonardo_scratch/large/userexternal/ftomba00/out_dadp/ordered_data.npy"
#endif

#ifdef LUMI
    /* NOTE(review): '~' is expanded by the shell, not by the C library, so
     * these paths are probably not resolved relative to $HOME when opened
     * from the program -- confirm and replace with an absolute path. */
    #define OUT_CLUSTER_ASSIGN "~/scratch_dadp/out_dadp/final_assignment.npy"
    #define OUT_DATA           "~/scratch_dadp/out_dadp/ordered_data.npy"
#endif

/* Fallback: no platform flag defined the output paths. */
#ifndef  OUT_CLUSTER_ASSIGN
    #define OUT_CLUSTER_ASSIGN "final_assignment.npy"
    #define OUT_DATA           "ordered_data.npy"
#endif


/*
 * MPI thread-support level requested by main() through MPI_Init_thread:
 * MPI_THREAD_FUNNELED when the build defines THREAD_FUNNELED,
 * MPI_THREAD_MULTIPLE otherwise.
 */
#ifdef THREAD_FUNNELED
    #define THREAD_LEVEL MPI_THREAD_FUNNELED
#else
    #define THREAD_LEVEL MPI_THREAD_MULTIPLE
#endif

/*
 * Forward declaration: the driver is defined after main() in this file.
 * Redundant (and harmless) if a project header already declares it.
 */
void simulate_master_read_and_scatter(int dims, size_t n, global_context_t *ctx);

/*
 * Program entry point.
 *
 * Initializes MPI (requesting THREAD_LEVEL when built with OpenMP), sets up
 * the global context on MPI_COMM_WORLD, runs the driver
 * simulate_master_read_and_scatter() on every rank, then releases the
 * context and finalizes MPI.
 *
 * Returns 0 on success, 1 when the MPI library cannot provide the requested
 * thread-support level.
 */
int main(int argc, char** argv) {
    #if defined (_OPENMP)
        int mpi_provided_thread_level;
        MPI_Init_thread(&argc, &argv, THREAD_LEVEL, &mpi_provided_thread_level);
        if (mpi_provided_thread_level < THREAD_LEVEL)
        {
            /* Only the name of the requested level differs between the
             * failure messages; the handling itself is the same. */
            const char *requested_level = "the requested";
            switch (THREAD_LEVEL)
            {
                case MPI_THREAD_FUNNELED:
                    requested_level = "MPI_THREAD_FUNNELED";
                    break;
                case MPI_THREAD_SERIALIZED:
                    requested_level = "MPI_THREAD_SERIALIZED";
                    break;
                case MPI_THREAD_MULTIPLE:
                    requested_level = "MPI_THREAD_MULTIPLE";
                    break;
                default:
                    break;
            }
            fprintf(stderr, "a problem arise when asking for %s level\n", requested_level);
            MPI_Finalize();
            return 1;
        }
    #else
        MPI_Init(&argc, &argv);
    #endif

    global_context_t ctx;

    ctx.mpi_communicator = MPI_COMM_WORLD;
    get_context(&ctx);

    #if defined (_OPENMP)
        mpi_printf(&ctx,"Running Hybrid (Openmp + MPI) code\n");
    #else
        mpi_printf(&ctx,"Running pure MPI code\n");
    #endif

    #if defined (THREAD_FUNNELED)
        mpi_printf(&ctx,"/!\\ Code built with MPI_THREAD_FUNNELED level\n");
    #else
        mpi_printf(&ctx,"/!\\ Code built with MPI_THREAD_MULTIPLE level\n");
    #endif

    /*
     * Every rank enters the driver. The (dims, n) arguments are kept as in
     * the original call sites; the driver defined in this file does not read
     * them and takes the dataset shape from the file loaded on rank 0.
     */
    if (ctx.mpi_rank == 0)
    {
        simulate_master_read_and_scatter(5, 1000000, &ctx);
    }
    else
    {
        simulate_master_read_and_scatter(0, 0, &ctx);
    }

    free_context(&ctx);
    MPI_Finalize();
    return 0;
}



/*
 * End-to-end driver executed by every rank.
 *
 * Steps, in order:
 *   1. check that the output paths are usable;
 *   2. rank 0 loads the dataset, then its shape is broadcast and the points
 *      are split in contiguous, near-equal shares sent with point-to-point
 *      messages;
 *   3. top kd-tree build and point exchange, local kd-tree build;
 *   4. neighbour search, intrinsic-dimension estimate, density estimate;
 *   5. Heuristic1 / Heuristic2 / Heuristic3 clustering steps;
 *   6. cluster assignment and local data are written to OUT_CLUSTER_ASSIGN
 *      and OUT_DATA.
 * Each stage is timed and reported through LOG_WRITE.
 *
 * Parameters:
 *   dims, n - not read by this implementation: the dataset shape comes from
 *             ctx->dims / ctx->n_points as set while loading on rank 0.
 *   ctx     - global context; on return ctx->local_n_points, ctx->local_data
 *             and ctx->local_datapoints have been set by this function.
 *
 * Ownership: the buffers stored in ctx->local_data and ctx->local_datapoints
 * are not freed here. NOTE(review): presumably released by free_context();
 * the clusters object returned by Heuristic1 is not freed either -- confirm.
 */
void simulate_master_read_and_scatter(int dims, size_t n, global_context_t *ctx) 
{
    (void)dims;
    (void)n;

    float_t *data = NULL;   /* full dataset, valid on the master rank only */
    TIME_DEF
    double elapsed_time;

    float_t z = 3;
    int halo = MY_TRUE;
    float_t tol = 0.002;
    int k = 300;

    /*
     * Validate the output paths with the same policy used by the write stage
     * at the end of this function: up to 6 ranks the ordered writer is used
     * and only the master checks the path; above that every rank goes through
     * the distributed check.
     * The previous condition (I_AM_MASTER && world_size <= 6) sent the
     * non-master ranks of small runs into the distributed check while the
     * master skipped it.
     */
    if(ctx -> world_size <= 6)
    {
        if(I_AM_MASTER)
        {
            test_file_path(OUT_DATA);
            test_file_path(OUT_CLUSTER_ASSIGN);
        }
    }
    else
    {
        test_distributed_file_path(ctx, OUT_DATA);
        test_distributed_file_path(ctx, OUT_CLUSTER_ASSIGN);
    }

    TIME_START;
    if (ctx->mpi_rank == 0) 
    {
        /* Alternative inputs kept for reference (notes from the original
         * author; some calls predate the current read_data_file signature). */
        //data = read_data_file(ctx, "../norm_data/50_blobs_more_var.npy", MY_TRUE);
        //ctx->dims = 2;
        //data = read_data_file(ctx, "../norm_data/blobs_small.npy", MY_FALSE);
        // std_g0163178_Me14_091_0000

        // 100M points
        // 2D
        // std_g2980844_091_0000
        //data = read_data_file(ctx,"../norm_data/huge_blobs.npy",MY_FALSE);
        // 2B points
        // data = read_data_file(ctx,"../norm_data/very_huge_blobs.npy",MY_FALSE);
        // data = read_data_file(ctx,"../norm_data/hd_blobs.npy",5,MY_FALSE);

        //1B points
        // data = read_data_file(ctx,"../norm_data/eds_box_acc_normalized",5,MY_FALSE);
        // data = read_data_file(ctx,"../norm_data/eds_box_6d",6,MY_FALSE);

        // 190M points
        // std_g2980844_091_0000
        // data = read_data_file(ctx,"../norm_data/std_g2980844_091_0000",5,MY_TRUE);

        /* 1M points ca. (active input) */
        data = read_data_file(ctx,"../norm_data/std_LR_091_0001",5,MY_TRUE);

        /* BOX */
        // data = read_data_file(ctx,"../norm_data/std_Box_256_30_092_0000",MY_TRUE);

        /* 8M points */
        // data = read_data_file(ctx,"../norm_data/std_g0144846_Me14_091_0001",5,MY_TRUE);

        //88M 
        // data = read_data_file(ctx,"../norm_data/std_g5503149_091_0000",MY_TRUE);

        //34 M
        //data = read_data_file(ctx,"../norm_data/std_g1212639_091_0001",MY_TRUE);

        /* For weak scalability: only the first quarter of the loaded points
         * is used from here on. */
        ctx->n_points = ctx->n_points / 4;
        //ctx->n_points = (ctx->n_points / 32) * ctx -> world_size;

        get_dataset_diagnostics(ctx, data);
    }

    /* Communicate the dimensionality and the total number of points. */
    MPI_Bcast(&(ctx->dims), 1, MPI_UINT32_T, 0, ctx->mpi_communicator);
    MPI_Bcast(&(ctx->n_points), 1, MPI_UINT64_T, 0, ctx->mpi_communicator);

    /*
     * Number of elements (points * dims) each rank receives and the offset of
     * its share inside the master buffer. The first (n_points % world_size)
     * ranks get one extra point.
     */
    idx_t *send_counts = (idx_t *)MY_MALLOC(ctx->world_size * sizeof(idx_t));
    idx_t *displacements = (idx_t *)MY_MALLOC(ctx->world_size * sizeof(idx_t));

    for (int p = 0; p < ctx->world_size; ++p) 
    {
        send_counts[p] = (ctx->n_points / ctx->world_size);
        send_counts[p] += (ctx->n_points % ctx->world_size) > p ? 1 : 0;
        send_counts[p] = send_counts[p] * ctx->dims;
        displacements[p] = (p == 0) ? 0 : displacements[p - 1] + send_counts[p - 1];
    }

    ctx->local_n_points = send_counts[ctx->mpi_rank] / ctx->dims;

    float_t *pvt_data = (float_t *)MY_MALLOC(send_counts[ctx->mpi_rank] * sizeof(float_t));

    /* Maximum number of elements (not bytes) carried by a single message, so
     * that every MPI count fits in an int. */
    uint64_t default_msg_len = 10000000;

    if(I_AM_MASTER)
    {
        /* The master share sits at the beginning of the buffer. */
        memcpy(pvt_data, data, ctx -> dims * ctx -> local_n_points * sizeof(float_t));

        for(int i = 1; i < ctx -> world_size; ++i)
        {
            /* Counter has the same type as send_counts: a per-rank share can
             * exceed INT_MAX elements on large datasets. */
            idx_t already_sent = 0;
            while(already_sent < send_counts[i])
            {
                int count_send = (int)MIN(default_msg_len, send_counts[i] - already_sent); 
                MPI_Send(data + displacements[i] + already_sent, count_send, MPI_MY_FLOAT, i, ctx -> mpi_rank, ctx -> mpi_communicator);
                already_sent += count_send;
                //DB_PRINT("[RANK 0] has sent to rank %d %d elements out of %lu\n",i, already_sent_points, send_counts[i]);
            }
            //DB_PRINT("------------------------------------------------\n");
        }
    }
    else
    {
        idx_t already_recvd = 0;
        while(already_recvd < send_counts[ctx -> mpi_rank])
        {
            /* Probe first to learn the size of the incoming chunk, then
             * receive exactly the probed message. */
            MPI_Status status;
            int count_recv; 

            MPI_Probe(0, MPI_ANY_TAG, ctx -> mpi_communicator, &status);
            MPI_Get_count(&status, MPI_MY_FLOAT, &count_recv);

            MPI_Recv(pvt_data + already_recvd, count_recv, MPI_MY_FLOAT, status.MPI_SOURCE, status.MPI_TAG, ctx -> mpi_communicator, MPI_STATUS_IGNORE);
            already_recvd += count_recv;
        }
    }

    elapsed_time = TIME_STOP;
    LOG_WRITE("Importing file ad scattering", elapsed_time);

    if (I_AM_MASTER) free(data);
    data = NULL;

    ctx->local_data = pvt_data;

    int k_global = 20;

    /* Point set describing the initial (pre-exchange) local share. */
    pointset_t original_ps;
    original_ps.data = ctx->local_data;
    original_ps.dims = ctx->dims;
    original_ps.n_points = ctx->local_n_points;
    original_ps.lb_box = (float_t*)MY_MALLOC(ctx -> dims * sizeof(float_t));
    original_ps.ub_box = (float_t*)MY_MALLOC(ctx -> dims * sizeof(float_t));

    top_kdtree_t tree;
    TIME_START;
    top_tree_init(ctx, &tree);
    elapsed_time = TIME_STOP;
    LOG_WRITE("Initializing global kdtree", elapsed_time);

    TIME_START;
    build_top_kdtree(ctx, &original_ps, &tree, k_global, tol);
    exchange_points(ctx, &tree);
    elapsed_time = TIME_STOP;
    LOG_WRITE("Top kdtree build and domain decomposition", elapsed_time);

    TIME_START;
    kdtree_v2 local_tree;
    kdtree_v2_init( &local_tree, ctx -> local_data, ctx -> local_n_points, (unsigned int)ctx -> dims);

    /* Per-point bookkeeping, reset to "nothing computed yet". */
    datapoint_info_t* dp_info = (datapoint_info_t*)MY_MALLOC(ctx -> local_n_points * sizeof(datapoint_info_t));            
    for(uint64_t i = 0; i < ctx -> local_n_points; ++i)
    {
        dp_info[i].ngbh.data = NULL;
        dp_info[i].ngbh.N = 0;
        dp_info[i].ngbh.count = 0;
        dp_info[i].g = 0.f;
        dp_info[i].log_rho = 0.f;
        dp_info[i].log_rho_c = 0.f;
        dp_info[i].log_rho_err = 0.f;
        dp_info[i].array_idx = -1;
        dp_info[i].kstar = -1;
        dp_info[i].is_center = -1;
        dp_info[i].cluster_idx = -1;
        //dp_info[i].halo_flag = 0;
    }
    ctx -> local_datapoints = dp_info;

    build_local_tree(ctx, &local_tree);
    elapsed_time = TIME_STOP;
    LOG_WRITE("Local trees init and build", elapsed_time);

    TIME_START;
    MPI_DB_PRINT("----- Performing ngbh search -----\n");
    MPI_Barrier(ctx -> mpi_communicator);

    mpi_ngbh_search(ctx, dp_info, &tree, &local_tree, ctx -> local_data, k);

    MPI_Barrier(ctx -> mpi_communicator);
    elapsed_time = TIME_STOP;
    LOG_WRITE("Total time for all knn search", elapsed_time);

    TIME_START;
    //float_t id = id_estimate(ctx, dp_info, ctx -> local_n_points, 0.9, MY_FALSE);
    float_t id = compute_ID_two_NN_ML(ctx, dp_info, ctx -> local_n_points, MY_FALSE);
    elapsed_time = TIME_STOP;
    LOG_WRITE("ID estimate", elapsed_time);

    MPI_DB_PRINT("ID %lf \n",id);

    TIME_START;
    compute_density_kstarnn_rma_v2(ctx, id, MY_FALSE);
    compute_correction(ctx, z);
    elapsed_time = TIME_STOP;
    LOG_WRITE("Density estimate", elapsed_time);

    TIME_START;
    clusters_t clusters = Heuristic1(ctx);
    elapsed_time = TIME_STOP;
    LOG_WRITE("H1", elapsed_time);

    TIME_START;
    clusters_allocate(&clusters, 1);
    Heuristic2(ctx, &clusters);
    elapsed_time = TIME_STOP;
    LOG_WRITE("H2", elapsed_time);

    TIME_START;
    Heuristic3(ctx, &clusters, z, halo);
    elapsed_time = TIME_STOP;
    LOG_WRITE("H3", elapsed_time);

    TIME_START;
    /* Flatten the cluster labels into a contiguous buffer for output. */
    int* cl = (int*)MY_MALLOC(ctx -> local_n_points * sizeof(int));
    for(uint64_t i = 0; i < ctx -> local_n_points; ++i) cl[i] = ctx -> local_datapoints[i].cluster_idx;

    /* The element size of the data buffer is sizeof(float_t): local_data is
     * allocated and scattered as float_t, which need not be a double. */
    if(ctx -> world_size <= 6)
    {
        big_ordered_buffer_to_file(ctx, cl, sizeof(int), ctx -> local_n_points, OUT_CLUSTER_ASSIGN);
        big_ordered_buffer_to_file(ctx, ctx -> local_data, sizeof(float_t), ctx -> local_n_points * ctx -> dims, OUT_DATA);
    }
    else
    {
        distributed_buffer_to_file(ctx, cl, sizeof(int), ctx -> local_n_points, OUT_CLUSTER_ASSIGN);
        distributed_buffer_to_file(ctx, ctx -> local_data, sizeof(float_t), ctx -> local_n_points * ctx -> dims, OUT_DATA);
    }
    free(cl);

    elapsed_time = TIME_STOP;
    LOG_WRITE("Write results to file", elapsed_time);

    top_tree_free(ctx, &tree);
    kdtree_v2_free(&local_tree);
    //clusters_free(&clusters);

    free(send_counts);
    free(displacements);
    //free(dp_info);

    /* The data pointer is detached before free_pointset() is called, so only
     * the remaining members of original_ps are handed over for release. */
    original_ps.data = NULL;
    free_pointset(&original_ps);
}
