#include <mpi.h>
#include <omp.h>
#include <stdio.h>
#include "../common/common.h"
#include "../tree/tree.h"
#include "../adp/adp.h"


//

/*
 * MPI thread-support level requested at initialisation time.
 * Building with -DTHREAD_FUNNELED selects MPI_THREAD_FUNNELED; any other
 * build requests MPI_THREAD_MULTIPLE.
 */
#if !defined(THREAD_FUNNELED)
    #define THREAD_LEVEL MPI_THREAD_MULTIPLE
#else
    #define THREAD_LEVEL MPI_THREAD_FUNNELED
#endif

/*
 * Program entry point.
 *
 * Initialises MPI (requesting THREAD_LEVEL thread support when compiled with
 * OpenMP), fills the global context for MPI_COMM_WORLD, runs the
 * read/scatter/clustering driver and finally releases the context and
 * shuts MPI down.
 *
 * Returns 0 on normal completion and 1 when the MPI library grants a lower
 * thread-support level than the one requested.
 */
int main(int argc, char** argv) {
    #if defined (_OPENMP)
        int mpi_provided_thread_level;
        MPI_Init_thread( &argc, &argv, THREAD_LEVEL, &mpi_provided_thread_level);
        if ( mpi_provided_thread_level < THREAD_LEVEL )
        {
            /* Only the name differs between levels; every level fails the same way. */
            const char *requested_level = "the requested";
            switch(THREAD_LEVEL)
            {
                case MPI_THREAD_FUNNELED:
                    requested_level = "MPI_THREAD_FUNNELED";
                    break;
                case MPI_THREAD_SERIALIZED:
                    requested_level = "MPI_THREAD_SERIALIZED";
                    break;
                case MPI_THREAD_MULTIPLE:
                    requested_level = "MPI_THREAD_MULTIPLE";
                    break;
                default:
                    break;
            }
            printf("a problem arise when asking for %s level\n", requested_level);
            MPI_Finalize();
            return 1;
        }
    #else
        MPI_Init(&argc, &argv);
    #endif

    global_context_t ctx;

    ctx.mpi_communicator = MPI_COMM_WORLD;
    get_context(&ctx);

    #if defined (_OPENMP)
        mpi_printf(&ctx,"Running Hybrid (Openmp + MPI) code\n");
    #else
        mpi_printf(&ctx,"Running pure MPI code\n");
    #endif

    #if defined (THREAD_FUNNELED)
        mpi_printf(&ctx,"/!\\ Code built with MPI_THREAD_FUNNELED level\n");
    #else
        mpi_printf(&ctx,"/!\\ Code built with MPI_THREAD_MULTIPLE level\n");
    #endif

    /*
     * Run the driver on every rank. Rank 0 passes a nominal dimension and
     * point count, the other ranks pass zeros.
     * NOTE(review): the driver defined in this file currently ignores both
     * values and takes everything from the context and the input file.
     */
    if(ctx.mpi_rank == 0)
    {
        simulate_master_read_and_scatter(5, 1000000, &ctx);
    }
    else
    {
        simulate_master_read_and_scatter(0, 0, &ctx);
    }

    free_context(&ctx);
    MPI_Finalize();
    return 0;
}



/*
 * Driver for the whole pipeline, executed collectively by every rank of
 * ctx->mpi_communicator:
 *   1. rank 0 reads the input file; dims and n_points are broadcast;
 *   2. the points are scattered in near-equal contiguous chunks;
 *   3. the top kd-tree is built and points are exchanged between ranks;
 *   4. local kd-tree build and neighbour search (k = 300);
 *   5. intrinsic-dimension estimate, density estimate and correction;
 *   6. clustering heuristics H1, H2 and H3.
 * Each timed stage is reported through LOG_WRITE.
 *
 * dims, n : currently unused; dimensionality is fixed to 5 on rank 0 and the
 *           number of points is derived from what read_data_file reports.
 * ctx     : global context; dims, n_points, local_n_points, local_data and
 *           local_datapoints are written here.
 *
 * Ownership: the scattered chunk and the per-point info array are handed to
 * ctx (local_data / local_datapoints) and are not freed by this function.
 */
void simulate_master_read_and_scatter(int dims, size_t n, global_context_t *ctx) 
{
    /* NULL on non-root ranks: MPI_Scatterv ignores the send buffer there,
     * but the pointer value itself must not be indeterminate. */
    float_t *data = NULL;
    TIME_DEF
    double elapsed_time;

    (void)dims;
    (void)n;

    if (ctx->mpi_rank == 0) 
    {
        /*
         * Input dataset. Alternatives noted by the original author:
         *   ../norm_data/50_blobs_more_var.npy, ../norm_data/50_blobs.npy  (dims = 2)
         *   ../norm_data/std_LR_091_0001              (about 1M points)
         *   ../norm_data/std_Box_256_30_092_0000      (BOX)
         *   ../norm_data/std_g0144846_Me14_091_0001   (8M points)
         *   ../norm_data/std_g1212639_091_0001        (34M points)
         *   ../norm_data/std_g5503149_091_0000        (88M points)
         *   ../norm_data/std_g2980844_091_0000        (190M points, in use)
         */
        data = read_data_file(ctx,"../norm_data/std_g2980844_091_0000",MY_TRUE);

        ctx->dims = 5;

        /* presumably read_data_file leaves the number of scalar values in
         * ctx->n_points; dividing by dims yields the number of points --
         * TODO confirm against read_data_file */
        ctx->n_points = ctx->n_points / ctx->dims;
    }

    /* communicate dimensionality and total number of points */
    MPI_Bcast(&(ctx->dims), 1, MPI_UINT32_T, 0, ctx->mpi_communicator);
    MPI_Bcast(&(ctx->n_points), 1, MPI_UINT64_T, 0, ctx->mpi_communicator);

    /*
     * Compute how many values each rank receives: n_points / world_size
     * points each, with the first (n_points % world_size) ranks taking one
     * extra point. Counts and displacements are expressed in scalar values.
     * NOTE(review): counts are int as required by MPI_Scatterv, so
     * n_points * dims must stay below INT_MAX.
     */
    int *send_counts = (int *)MY_MALLOC(ctx->world_size * sizeof(int));
    int *displacements = (int *)MY_MALLOC(ctx->world_size * sizeof(int));

    displacements[0] = 0;
    for (int p = 0; p < ctx->world_size; ++p) 
    {
        send_counts[p] = ctx->n_points / ctx->world_size;
        send_counts[p] += (ctx->n_points % ctx->world_size) > p ? 1 : 0;
        send_counts[p] = send_counts[p] * ctx->dims;
        if (p > 0)
        {
            displacements[p] = displacements[p - 1] + send_counts[p - 1];
        }
    }

    ctx->local_n_points = send_counts[ctx->mpi_rank] / ctx->dims;

    float_t *pvt_data = (float_t *)MY_MALLOC(send_counts[ctx->mpi_rank] * sizeof(float_t));

    MPI_Scatterv(data, send_counts, displacements, MPI_MY_FLOAT, pvt_data, send_counts[ctx->mpi_rank], MPI_MY_FLOAT, 0, ctx->mpi_communicator);

    /* The master copy and the scatter layout are not referenced again:
     * release them now instead of keeping the full dataset alive on rank 0
     * for the rest of the run. free(NULL) is a no-op on the other ranks. */
    free(data);
    data = NULL;
    free(send_counts);
    free(displacements);

    ctx->local_data = pvt_data;

    int k_global = 20;

    pointset_t original_ps;
    original_ps.data = ctx->local_data;
    original_ps.dims = ctx->dims;
    original_ps.n_points = ctx->local_n_points;
    original_ps.lb_box = (float_t*)MY_MALLOC(ctx -> dims * sizeof(float_t));
    original_ps.ub_box = (float_t*)MY_MALLOC(ctx -> dims * sizeof(float_t));

    float_t tol = 0.002;

    top_kdtree_t tree;
    TIME_START;
    top_tree_init(ctx, &tree);
    elapsed_time = TIME_STOP;
    LOG_WRITE("Initializing global kdtree", elapsed_time);

    TIME_START;
    build_top_kdtree(ctx, &original_ps, &tree, k_global, tol);
    exchange_points(ctx, &tree);
    elapsed_time = TIME_STOP;
    LOG_WRITE("Top kdtree build and domain decomposition", elapsed_time);

    TIME_START;
    kdtree_v2 local_tree;
    kdtree_v2_init( &local_tree, ctx -> local_data, ctx -> local_n_points, (unsigned int)ctx -> dims);
    int k = 300;

    datapoint_info_t* dp_info = (datapoint_info_t*)MY_MALLOC(ctx -> local_n_points * sizeof(datapoint_info_t));            
    /* give every field a defined value, to cope with valgrind */
    for(uint64_t i = 0; i < ctx -> local_n_points; ++i)
    {
        dp_info[i].ngbh.data = NULL;
        dp_info[i].ngbh.N = 0;
        dp_info[i].ngbh.count = 0;
        dp_info[i].g = 0.f;
        dp_info[i].log_rho = 0.f;
        dp_info[i].log_rho_c = 0.f;
        dp_info[i].log_rho_err = 0.f;
        dp_info[i].array_idx = -1;
        dp_info[i].kstar = -1;
        dp_info[i].is_center = -1;
        dp_info[i].cluster_idx = -1;
    }

    build_local_tree(ctx, &local_tree);
    elapsed_time = TIME_STOP;
    LOG_WRITE("Local trees init and build", elapsed_time);

    TIME_START;
    MPI_DB_PRINT("----- Performing ngbh search -----\n");
    MPI_Barrier(ctx -> mpi_communicator);

    mpi_ngbh_search(ctx, dp_info, &tree, &local_tree, ctx -> local_data, k);

    /* barrier so that the logged time covers the search on every rank */
    MPI_Barrier(ctx -> mpi_communicator);
    elapsed_time = TIME_STOP;
    LOG_WRITE("Total time for all knn search", elapsed_time);

    TIME_START;
    float_t id = compute_ID_two_NN_ML(ctx, dp_info, ctx -> local_n_points, MY_FALSE);
    elapsed_time = TIME_STOP;
    LOG_WRITE("ID estimate", elapsed_time);

    MPI_DB_PRINT("ID %lf \n",id);

    TIME_START;

    float_t  z = 2;

    /* the context takes over the per-point info from here on */
    ctx -> local_datapoints = dp_info;
    compute_density_kstarnn_rma_v2(ctx, id, MY_FALSE);
    compute_correction(ctx, z);
    elapsed_time = TIME_STOP;
    LOG_WRITE("Density estimate", elapsed_time);

    TIME_START;
    clusters_t clusters = Heuristic1(ctx, MY_FALSE);
    elapsed_time = TIME_STOP;
    LOG_WRITE("H1", elapsed_time);

    TIME_START;
    clusters_allocate(&clusters, 1);
    Heuristic2(ctx, &clusters);
    elapsed_time = TIME_STOP;
    LOG_WRITE("H2", elapsed_time);

    TIME_START;
    int halo = 0;
    Heuristic3(ctx, &clusters, z, halo);
    elapsed_time = TIME_STOP;
    LOG_WRITE("H3", elapsed_time);

    /* optional dump, enabled at build time with -DWRITE_NGBH */
    #if defined (WRITE_NGBH)
        ordered_data_to_file(ctx);
    #endif

    top_tree_free(ctx, &tree);
    kdtree_v2_free(&local_tree);

    /* dp_info is intentionally not freed: it now lives in
     * ctx->local_datapoints (presumably released by free_context -- confirm). */

    /* Detach the point buffer before releasing the pointset so that only the
     * bounding-box arrays allocated above are handed to free_pointset. */
    original_ps.data = NULL;
    free_pointset(&original_ps);
}
