Skip to content
Snippets Groups Projects
Select Git revision
  • 64fd89069514ed2e6ef6e62bd3c877a62d75bf5a
  • master default
  • rocky-linux-9
  • rocky-linux-8
  • debian11
  • v1.1.2
  • v1.1.1
  • v1.1.0
  • v1.0.0
  • v1.0
  • beta1
11 results

PlainClient.cpp

Blame
  • tree.c 91.78 KiB
    /*
     * Implementation of a distributed-memory kd-tree.
     * The idea is to have a top-level domain decomposition, with a shallow
     * top-level tree shared among the computational nodes.
     *
     * Each domain then works on its own set of points separately; later on,
     * the top tree serves as a map telling which processor to ask for
     * neighbors.
     */
    #include "tree.h"
    #include "heap.h"
    #include "kdtreeV2.h"
    #include "mpi.h"
    #include <math.h>
    #include <stdint.h>
    #include <stdio.h>
    #include <stdlib.h>
    #include <string.h>
    #include <omp.h>
    #include <sys/sysinfo.h>
    
    //#define WRITE_NGBH
    //#define WRITE_TOP_NODES
    #define WRITE_DENSITY
    
    /*
     * Maximum number of bytes sent with a single MPI send/recv, used while
     * communicating the results of the neighbour search.
     * The hard limit would be 4 GB (4294967296); slices of 10 MB are used.
     */
    #define MAX_MSG_SIZE 10000000
    
    /* MPI datatype matching float_t */
    #ifdef USE_FLOAT32
    #define MPI_MY_FLOAT MPI_FLOAT
    #else
    #define MPI_MY_FLOAT MPI_DOUBLE
    #endif
    
    /* Debug trace of rank and line; the expansion already ends with ';'. Needs `ctx` in scope. */
    #define HERE printf("%d reached line %d\n", ctx -> mpi_rank, __LINE__);
    
    /*
     * True on rank 0. Needs `ctx` in scope. Parenthesised so it composes
     * correctly inside larger expressions (e.g. `2 * I_AM_MASTER`).
     */
    #define I_AM_MASTER (ctx->mpi_rank == 0)
    
    /* Which side of its parent a queued partition / top-tree node lies on */
    #define TOP_TREE_RCH 1
    #define TOP_TREE_LCH 0
    #define NO_CHILD (-1)
    
    unsigned int data_dims;
    
    
    /*
     * qsort-compatible three-way comparator for float_t values.
     * Returns 1 if *a > *b, -1 if *a < *b, 0 otherwise (including NaNs).
     */
    int cmp_float_t(const void* a, const void* b)
    {
        const float_t lhs = *(const float_t*)a;
        const float_t rhs = *(const float_t*)b;
        if (lhs > rhs) return 1;
        if (lhs < rhs) return -1;
        return 0;
    }
    
    
    /* Report a fatal error while reading fname, close f (if open) and terminate. */
    static void read_data_file_fail(FILE *f, const char *fname, const char *what)
    {
        fprintf(stderr, "read_data_file: %s (%s)\n", what, fname);
        if (f) fclose(f);
        exit(1);
    }

    /*
     * Read a raw binary file of floating-point values into a newly malloc'd
     * float_t buffer; the caller owns the buffer and must free it.
     *
     * ctx             : global context; ctx->n_points is set to the number of
     *                   values actually read from the file.
     * fname           : path of the file to read.
     * file_in_float32 : non-zero if the file stores 4-byte floats, zero if it
     *                   stores 8-byte doubles.
     *
     * Terminates the process with status 1 if the file cannot be opened or
     * sized, or if memory cannot be allocated.
     */
    float_t *read_data_file(global_context_t *ctx, const char *fname,
                            const int file_in_float32) 
    {
        /* binary data: "rb" avoids newline translation on platforms that do it */
        FILE *f = fopen(fname, "rb");
        if (!f) read_data_file_fail(NULL, fname, "cannot open file");

        if (fseek(f, 0, SEEK_END) != 0) read_data_file_fail(f, fname, "cannot seek");
        long file_size = ftell(f);
        if (file_size < 0) read_data_file_fail(f, fname, "cannot determine size");
        rewind(f);

        const size_t input_float_size = file_in_float32 ? sizeof(float) : sizeof(double);
        size_t n = (size_t)file_size / input_float_size;

        float_t *data = malloc(n * sizeof *data);
        if (n > 0 && !data) read_data_file_fail(f, fname, "out of memory");

        size_t n_read;
        if (file_in_float32) 
        {
            float *df = malloc(n * sizeof *df);
            if (n > 0 && !df) read_data_file_fail(f, fname, "out of memory");
            n_read = fread(df, sizeof *df, n, f);
            for (size_t i = 0; i < n_read; ++i) data[i] = (float_t)(df[i]);
            free(df);
        } 
        else 
        {
            double *df = malloc(n * sizeof *df);
            if (n > 0 && !df) read_data_file_fail(f, fname, "out of memory");
            n_read = fread(df, sizeof *df, n, f);
            for (size_t i = 0; i < n_read; ++i) data[i] = (float_t)(df[i]);
            free(df);
        }
        fclose(f);

        /* fread returns a count of values, not bytes */
        mpi_printf(ctx, "Read %zu values\n", n_read);

        /* on a short read only the first n_read entries of data are valid */
        ctx->n_points = n_read;
        return data;
    }
    
    /* quickselect for an element along a dimension */
    
    /* Swap, component by component, the vec_len-long vectors starting at a and b. */
    void swap_data_element(float_t *a, float_t *b, int vec_len) {
        float_t *const stop = a + vec_len;
        while (a < stop) 
        {
            const float_t held = *a;
            *a++ = *b;
            *b++ = held;
        }
    }
    
    /*
     * Three-way comparison of two vectors on component compare_dim, with the
     * ordering REVERSED with respect to the usual convention:
     * returns -1 if a[d] > b[d], 1 if a[d] < b[d], 0 otherwise.
     */
    int compare_data_element(float_t *a, float_t *b, int compare_dim) {
        const float_t diff = a[compare_dim] - b[compare_dim];
        if (diff > 0.) return -1;
        if (diff < 0.) return 1;
        return 0;
    }
    
    /*
     * Lomuto partition of the vectors array[left..right] (both inclusive,
     * each vec_len components wide) on component compare_dim, using the
     * vector at pivot_index as pivot. Vectors whose component is strictly
     * smaller than the pivot's are moved before it.
     * Returns the final index of the pivot vector.
     */
    int partition_data_element(float_t *array, int vec_len, int compare_dim,
                               int left, int right, int pivot_index) 
    {
        float_t *const last = array + right * vec_len;

        /* park the pivot in the last slot; the scan below never touches it */
        swap_data_element(array + pivot_index * vec_len, last, vec_len);
        const float_t pivot_value = last[compare_dim];

        int boundary = left;
        for (int cur = left; cur < right; ++cur) 
        {
            float_t *candidate = array + cur * vec_len;
            if (candidate[compare_dim] < pivot_value) 
            {
                swap_data_element(array + boundary * vec_len, candidate, vec_len);
                ++boundary;
            }
        }

        /* put the pivot between the "smaller" and "not smaller" groups */
        swap_data_element(array + boundary * vec_len, last, vec_len);
        return boundary;
    }
    
    /*
     * Quickselect on the vectors array[left..right] (inclusive) along
     * component compare_dim. Afterwards the vector at index n is the one that
     * would be there if the range were sorted. Returns the index of the
     * selected vector (n, assuming left <= n <= right).
     *
     * The leftmost element is used as pivot, so already-sorted input needs
     * O(n) partition rounds. The loop form keeps stack usage constant, where
     * the former (tail-)recursive form could reach depth O(n) unless the
     * compiler eliminated the tail calls.
     */
    int qselect_data_element(float_t *array, int vec_len, int compare_dim, int left, int right, int n) 
    {
        while (left != right) 
        {
            /* pivot = left; a random pivot would be: left + rand() % (right - left + 1) */
            int pivot_index = partition_data_element(array, vec_len, compare_dim, left, right, left);

            /* the pivot is now in its final sorted position */
            if (n == pivot_index) return pivot_index;

            if (n < pivot_index) right = pivot_index - 1;
            else                 left  = pivot_index + 1;
        }
        return left;
    }
    
    /*
     * Rearrange array (array_size vectors of vec_len components) so that the
     * k-th smallest vector on component compare_dim (k is 1-based) sits at
     * index k-1. Returns that index.
     *
     * k is clamped to [1, array_size], since out-of-range values could make
     * the partition step index outside the array. Returns -1 without touching
     * the array when array_size <= 0.
     */
    int quickselect_data_element(float_t *array, int vec_len, int array_size, int compare_dim, int k) 
    {
        if (array_size <= 0) return -1;
        if (k < 1) k = 1;
        if (k > array_size) k = array_size;
        return qselect_data_element(array, vec_len, compare_dim, 0, array_size - 1, k - 1);
    }
    
    /* Component compared by compare_data_element_sort; set it before calling qsort. */
    int CMP_DIM;
    /*
     * qsort comparator for float_t vectors: ascending order on component
     * CMP_DIM. Returns 1, -1 or 0. Because it reads the global CMP_DIM,
     * concurrent sorts on different components would interfere.
     */
    int compare_data_element_sort(const void *a, const void *b) {
        const float_t *va = a;
        const float_t *vb = b;
        const float_t delta = va[CMP_DIM] - vb[CMP_DIM];
        return (delta > 0.) ? 1 : ((delta < 0.) ? -1 : 0);
    }
    
    /*
     * Compute the global axis-aligned bounding box of all points.
     *
     * Allocates ctx->lb_box and ctx->ub_box (ctx->dims entries each; buffers
     * previously stored there are not freed). Each rank fills them with the
     * min/max of its own ctx->local_data (ctx->local_n_points vectors). They
     * are then reduced in place with MPI_MIN / MPI_MAX over
     * ctx->mpi_communicator, so every rank ends up with the same box.
     * Collective: all ranks must call it.
     */
    void compute_bounding_box(global_context_t *ctx) {
        ctx->lb_box = (float_t *)malloc(ctx->dims * sizeof(float_t));
        ctx->ub_box = (float_t *)malloc(ctx->dims * sizeof(float_t));

        float_t *lower = ctx->lb_box;
        float_t *upper = ctx->ub_box;
        const float_t *pts = ctx->local_data;

        /*
         * +/-INFINITY are the neutral elements of min/max. Ranks without points
         * leave the reduction unaffected, and coordinates beyond the former
         * +/-99999999 sentinels no longer produce a wrong box.
         */
        for (size_t d = 0; d < ctx->dims; ++d) 
        {
            lower[d] = INFINITY;
            upper[d] = -INFINITY;
        }

        /* local min/max per dimension */
        for (size_t i = 0; i < ctx->local_n_points; ++i) 
        {
            for (size_t d = 0; d < ctx->dims; ++d) 
            {
                lower[d] = MIN(pts[i * ctx->dims + d], lower[d]);
                upper[d] = MAX(pts[i * ctx->dims + d], upper[d]);
            }
        }

        /* global box */
        MPI_Allreduce(MPI_IN_PLACE, lower, ctx->dims, MPI_MY_FLOAT, MPI_MIN, ctx->mpi_communicator);
        MPI_Allreduce(MPI_IN_PLACE, upper, ctx->dims, MPI_MY_FLOAT, MPI_MAX, ctx->mpi_communicator);
    }
    
    /* FIFO queue holding the partitions that still have to be processed */
    /*
     * Insert partition p into the queue. New entries go to the front
     * (index 0) while dequeue_partition takes from the back, so partitions
     * are processed in FIFO order.
     *
     * When the storage is full it grows by about 10% (at least one slot).
     * The previous code passed the element count instead of the byte size to
     * realloc, which shrank the buffer. It also computed no growth at all for
     * capacities below 10.
     * Terminates the process if memory cannot be obtained.
     */
    void enqueue_partition(partition_queue_t *queue, partition_t p) 
    {
        if (queue->count == queue->_capacity) 
        {
            size_t old_capacity = (size_t)queue->_capacity;
            size_t new_capacity = (size_t)(old_capacity * 1.10);
            if (new_capacity <= old_capacity) new_capacity = old_capacity + 1;

            /* realloc takes a size in bytes; keep the old buffer on failure */
            partition_t *grown = realloc(queue->data, new_capacity * sizeof *grown);
            if (!grown) 
            {
                fprintf(stderr, "enqueue_partition: cannot grow queue to %zu entries\n", new_capacity);
                exit(1);
            }
            queue->data      = grown;
            queue->_capacity = new_capacity;
        }
        /* shift existing entries one slot towards the back, insert at the front */
        memmove(queue->data + 1, queue->data, queue->count * sizeof(partition_t));
        queue->data[0] = p;
        queue->count++;
    }
    
    /*
     * Remove and return the partition at the back of the queue (the oldest
     * one, since enqueue_partition inserts at the front).
     * The caller must ensure queue->count > 0.
     */
    partition_t dequeue_partition(partition_queue_t *queue) 
    {
        queue->count -= 1;
        return queue->data[queue->count];
    }
    
    /*
     * Test helper for median selection along dimension 1:
     * - each rank quickselects the median of ctx->local_data (reordering it);
     * - the per-rank medians are exchanged with MPI_Allgather (collective) and
     *   sorted on that dimension;
     * - rank 0 reports which fraction of `data` (ctx->n_points vectors, read
     *   only on rank 0 and assumed to be the whole dataset there) lies at or
     *   below the median of medians.
     * Assumes at least one local point per rank.
     */
    void compute_medians_and_check(global_context_t *ctx, float_t *data) {
        float_t prop = 0.5;
        int d = 1;

        /* k is 1-based: with a single local point it would be 0 and kk would
         * point before the start of local_data */
        int k = (int)(ctx->local_n_points * prop);
        if (k < 1) k = 1;
        int kk = (k - 1) * ctx->dims;

        quickselect_data_element(ctx->local_data, (int)(ctx->dims), (int)(ctx->local_n_points), d, k);

        float_t *medians_rcv = (float_t *)malloc(ctx->dims * ctx->world_size * sizeof(float_t));

        /* exchange the medians (whole vectors) among all ranks */
        MPI_Allgather(ctx->local_data + kk, ctx->dims, MPI_MY_FLOAT, medians_rcv, ctx->dims, MPI_MY_FLOAT, ctx->mpi_communicator);

        CMP_DIM = d;
        qsort(medians_rcv, ctx->world_size, ctx->dims * sizeof(float_t), compare_data_element_sort);

        if (ctx->mpi_rank == 0) {
            uint64_t count = 0;
            int idx = (int)(prop * (ctx->world_size));
            for (size_t i = 0; i < ctx->n_points; ++i) 
            {
                count += data[i * ctx->dims + d] <= medians_rcv[idx * ctx->dims + d];
            }
            mpi_printf(ctx, "Choosing %lf percentile on dimension %d: empirical prop %lf\n", prop, d, (float_t)count / (float_t)(ctx->n_points));
        }
        free(medians_rcv);
    }
    
    /*
     * Fraction of all points of the distributed pointset ps whose component d
     * is <= g.x_guess. Each rank counts locally. Counts and sizes are then
     * summed over ctx->mpi_communicator (collective), and count / total is
     * returned on every rank.
     * prop is unused and kept only for interface compatibility.
     */
    float_t check_pc_pointset_parallel(global_context_t *ctx, pointset_t *ps, guess_t g, int d, float_t prop) {
        (void)prop;

        /* 64-bit counters: int counts overflowed beyond 2^31 points */
        uint64_t below_and_total[2] = {0, (uint64_t)ps->n_points};
        for (size_t i = 0; i < ps->n_points; ++i) 
        {
            below_and_total[0] += ps->data[i * ps->dims + d] <= g.x_guess;
        }

        MPI_Allreduce(MPI_IN_PLACE, below_and_total, 2, MPI_UINT64_T, MPI_SUM, ctx->mpi_communicator);

        return (float_t)below_and_total[0] / (float_t)below_and_total[1];
    }
    
    /*
     * Compute the global bounding box of the distributed pointset ps.
     * Each rank takes the min/max of its own ps->data (ps->n_points vectors
     * of ps->dims components). The results are combined in place with
     * MPI_MIN / MPI_MAX over ctx->mpi_communicator.
     * Writes into the caller-allocated ps->lb_box / ps->ub_box (ps->dims
     * entries each). Collective: all ranks must call it.
     */
    void compute_bounding_box_pointset(global_context_t *ctx, pointset_t *ps) {
        float_t *lower = ps->lb_box;
        float_t *upper = ps->ub_box;
        const float_t *pts = ps->data;

        /*
         * +/-INFINITY are the neutral elements of min/max. Ranks holding no
         * points leave the reduction unaffected, and coordinates beyond the
         * former +/-99999999 sentinels no longer produce a wrong box.
         */
        for (size_t d = 0; d < ps->dims; ++d)
        {
            lower[d] = INFINITY;
            upper[d] = -INFINITY;
        }

        /* local min/max per dimension */
        for (size_t i = 0; i < ps->n_points; ++i) 
        {
            for (size_t d = 0; d < ps->dims; ++d) 
            {
                lower[d] = MIN(pts[i * ps->dims + d], lower[d]);
                upper[d] = MAX(pts[i * ps->dims + d], upper[d]);
            }
        }

        /* global box */
        MPI_Allreduce(MPI_IN_PLACE, lower, ps->dims, MPI_MY_FLOAT, MPI_MIN, ctx->mpi_communicator);
        MPI_Allreduce(MPI_IN_PLACE, upper, ps->dims, MPI_MY_FLOAT, MPI_MAX, ctx->mpi_communicator);
    }
    
    
    /*
     * Turn a global histogram into a guess for the pc-quantile along dimension d.
     *
     * global_bin_counts holds k_global equal-width bins over
     * [ps->lb_box[d], ps->ub_box[d]]. The result carries:
     * - bin_idx: the first bin where the cumulative fraction reaches pc;
     * - x_guess: linear interpolation of the cumulative fraction inside that bin.
     * ep is left at 0.
     */
    guess_t retrieve_guess_pure(global_context_t *ctx, pointset_t *ps,
                            uint64_t *global_bin_counts, 
                            int k_global, int d, float_t pc)
    {
        float_t total_count = 0.;
        for (int i = 0; i < k_global; ++i) total_count += (float_t)global_bin_counts[i];

        /*
         * Walk the cumulative histogram. Stop at the last bin so that pc >= 1,
         * or rounding in the cumulative sum, cannot index past the array.
         */
        float_t cumulative_count = 0;
        int idx = 0;
        while (idx < k_global - 1 &&
               (cumulative_count + (float_t)global_bin_counts[idx]) / total_count < pc) 
        {
            cumulative_count += (float_t)global_bin_counts[idx];
            idx++;
        }

        /* find the best spot inside the bin */
        float_t box_lb = ps->lb_box[d];
        float_t box_ub = ps->ub_box[d];
        float_t global_bin_width = (box_ub - box_lb) / (float_t)k_global;

        float_t x0 = box_lb + (global_bin_width * (idx));
        float_t x1 = box_lb + (global_bin_width * (idx + 1));

        float_t y0 = (cumulative_count) / total_count;
        float_t y1 = (cumulative_count + global_bin_counts[idx]) / total_count;

        /* an empty bin (y1 == y0), or no points at all, would divide by zero:
         * fall back to the bin's left edge */
        float_t x_guess = (y1 > y0) ? (pc - y0) / (y1 - y0) * (x1 - x0) + x0 : x0;

        guess_t g = {.bin_idx = idx, .x_guess = x_guess};
        return g;
    }
    
    
    /*
     * Sanity check, executed on rank 0 only: bin `data` (ctx->n_points
     * vectors of ctx->dims components, assumed to be the whole dataset on
     * rank 0) into k equal-width bins along dimension d of the box
     * ctx->lb_box / ctx->ub_box.
     * The counts are consumed only by the disabled debug print below.
     */
    void global_binning_check(global_context_t *ctx, float_t *data, int d, int k) 
    {
        if (!(I_AM_MASTER) || k <= 0) return;

        int *counts = (int *)calloc((size_t)k, sizeof(int));
        if (!counts) return;

        float_t box_lb = ctx->lb_box[d];
        float_t bin_width = (ctx->ub_box[d] - box_lb) / (float_t)k;

        for (size_t i = 0; i < ctx->n_points; ++i) 
        {
            int bin_idx = 0;
            /* zero-width box: everything in bin 0 (avoids converting NaN to int) */
            if (bin_width > 0) 
            {
                /* same rule as compute_pure_global_binning: the point on the upper
                 * border is counted in the last bin instead of being dropped */
                bin_idx = MIN((int)((data[i * ctx->dims + d] - box_lb) / bin_width), k - 1);
                if (bin_idx < 0) bin_idx = 0;
            }
            counts[bin_idx]++;
        }

        /*
        MPI_DB_PRINT("Actual bins: [");
        for(int bin_idx = 0; bin_idx < k; ++bin_idx) MPI_DB_PRINT("%d ", counts[bin_idx]);
        MPI_DB_PRINT("]\n");
        */

        free(counts);
    }
    
    /*
     * Histogram the distributed pointset ps along dimension d into k_global
     * equal-width bins spanning [ps->lb_box[d], ps->ub_box[d]]. The box must
     * already be computed, e.g. by compute_bounding_box_pointset.
     * Per-rank counts are summed into global_bin_counts (k_global entries,
     * identical on every rank afterwards). Collective.
     */
    void compute_pure_global_binning(global_context_t *ctx, pointset_t *ps,
                                     uint64_t *global_bin_counts, int k_global,
                                     int d) 
    {
        uint64_t *local_bin_count = (uint64_t *)calloc((size_t)k_global, sizeof(uint64_t));
        for (int k = 0; k < k_global; ++k) global_bin_counts[k] = 0;

        const float_t box_lb = ps->lb_box[d];
        const float_t bin_w = (ps->ub_box[d] - box_lb) / (float_t)k_global;

        #pragma omp parallel for
        for (size_t i = 0; i < ps->n_points; ++i) 
        {
            const float_t p = ps->data[i * ps->dims + d];
            int bin_idx = 0;
            /*
             * A zero-width box (all points share coordinate d) puts everything
             * in bin 0. Otherwise (p - lb) / 0 is NaN, and converting it to int
             * is undefined and produced out-of-range indices.
             */
            if (bin_w > 0) 
            {
                /* the point on the upper border would map to k_global: clamp it */
                bin_idx = MIN((int)((p - box_lb) / bin_w), k_global - 1);
                if (bin_idx < 0) bin_idx = 0;
            }

            #pragma omp atomic update
            local_bin_count[bin_idx]++;
        }

        /* MPI_UINT64_T matches uint64_t exactly (unsigned long is 32-bit on LLP64) */
        MPI_Allreduce(local_bin_count, global_bin_counts, k_global, MPI_UINT64_T, MPI_SUM, ctx->mpi_communicator);
        free(local_bin_count);
    }
    
    /*
     * Partition the vectors with indices in [left, right) (right EXCLUSIVE,
     * each vec_len components wide). Vectors whose component compare_dim is
     * strictly less than pivot_value are moved to the front of the range.
     * No pivot element is moved.
     * Returns left + (number of such vectors), i.e. the index of the first
     * vector that is not less than pivot_value.
     */
    int partition_data_around_value(float_t *array, int vec_len, int compare_dim,
                                    int left, int right, float_t pivot_value) 
    {
        int split = left;
        for (int cur = left; cur < right; ++cur) 
        {
            float_t *vec = array + cur * vec_len;
            if (!(vec[compare_dim] < pivot_value)) continue;

            swap_data_element(array + split * vec_len, vec, vec_len);
            ++split;
        }
        return split;
    }
    
    /*
     * Refine a binned guess of the f-quantile along dimension d until the
     * empirical fraction of points <= x_guess, measured over the whole
     * distributed pointset ps, is within tolerance of f.
     *
     * Each round works as follows:
     * - the points of the current candidate bin are gathered in place into a
     *   contiguous sub-pointset by partitioning ps->data (ps->data is
     *   reordered as a side effect);
     * - the sub-pointset is re-binned into k_global bins;
     * - the target fraction is rescaled to the inside of the bin.
     *
     * Collective: every rank must call it. The loop-control quantities (bin
     * counts, bin width, ep) are computed from MPI reductions, so all ranks
     * take the same branches.
     *
     * The loop also stops, returning the best guess so far, in three cases.
     * Previously the first two could divide by zero or loop forever:
     * - the candidate bin is empty;
     * - its width collapsed to zero, e.g. many points share the same
     *   coordinate so f cannot be reached within tolerance;
     * - REFINE_MAX_ITER rounds have run.
     */
    guess_t refine_pure_binning(global_context_t *ctx, pointset_t *ps,
                                guess_t best_guess, uint64_t *global_bin_count,
                                int k_global, int d, float_t f, float_t tolerance)
    {
        /* Each round narrows the interval to one of k_global bins, so real
         * inputs converge long before this cap. */
        enum { REFINE_MAX_ITER = 128 };

        if (fabs(best_guess.ep - f) < tolerance) 
        {
            return best_guess;
        }

        /* points left of the candidate bin, and points overall */
        float_t starting_cumulative = 0;
        float_t total_count = 0;
        for (int i = 0; i < best_guess.bin_idx; ++i) starting_cumulative += global_bin_count[i];
        for (int i = 0; i < k_global; ++i) total_count += global_bin_count[i];

        /* edges of the candidate bin */
        float_t bin_w  = (ps->ub_box[d] - ps->lb_box[d]) / k_global;
        float_t bin_lb = ps->lb_box[d] + (bin_w * (best_guess.bin_idx));
        float_t bin_ub = ps->lb_box[d] + (bin_w * (best_guess.bin_idx + 1));

        uint64_t *tmp_global_bins = (uint64_t *)malloc(sizeof(uint64_t) * k_global);
        for (int i = 0; i < k_global; ++i) tmp_global_bins[i] = global_bin_count[i];

        /* view on the points of the candidate bin; its box buffers are reused
         * across rounds instead of being reallocated each time */
        pointset_t tmp_ps;
        tmp_ps.dims   = ps->dims;
        tmp_ps.lb_box = (float_t*)malloc(ctx -> dims * sizeof(float_t));
        tmp_ps.ub_box = (float_t*)malloc(ctx -> dims * sizeof(float_t));

        for (int iter = 0; iter < REFINE_MAX_ITER && fabs(best_guess.ep - f) > tolerance; ++iter) 
        {
            uint64_t bin_count = tmp_global_bins[best_guess.bin_idx];

            /* empty bin (ff below would divide by zero) or zero-width bin
             * (no further refinement possible): stop */
            if (bin_count == 0 || !(bin_w > 0)) break;

            /* fraction of the candidate bin's points that must lie below the target */
            float_t ff = (f * total_count - starting_cumulative) / (float_t)bin_count;

            for (int i = 0; i < k_global; ++i) tmp_global_bins[i] = 0;

            /* move the points with bin_lb <= x_d < bin_ub into ps->data[start_idx, end_idx) */
            int end_idx   = partition_data_around_value(ps->data, (int)ps->dims, d, 0, (int)ps->n_points, bin_ub);
            int start_idx = partition_data_around_value(ps->data, (int)ps->dims, d, 0, end_idx, bin_lb);

            tmp_ps.n_points = end_idx - start_idx;
            tmp_ps.data     = ps->data + start_idx * ps->dims;

            compute_bounding_box_pointset(ctx, &tmp_ps);

            MPI_Barrier(ctx->mpi_communicator);
            compute_pure_global_binning(ctx, &tmp_ps, tmp_global_bins, k_global, d);

            best_guess = retrieve_guess_pure(ctx, &tmp_ps, tmp_global_bins, k_global, d, ff);
            /* the quality of the guess is always measured on the full pointset */
            best_guess.ep = check_pc_pointset_parallel(ctx, ps, best_guess, d, f);

            /* next candidate bin, inside the re-binned sub-box */
            bin_w  = (tmp_ps.ub_box[d] - tmp_ps.lb_box[d]) / k_global;
            bin_lb = tmp_ps.lb_box[d] + (bin_w * (best_guess.bin_idx));
            bin_ub = tmp_ps.lb_box[d] + (bin_w * (best_guess.bin_idx + 1));

            for (int i = 0; i < best_guess.bin_idx; ++i) starting_cumulative += tmp_global_bins[i];
        }

        free(tmp_ps.lb_box);
        free(tmp_ps.ub_box);
        free(tmp_global_bins);

        return best_guess;
    }
    
    /* Initialise an empty partition queue with room for 1000 entries. */
    void init_queue(partition_queue_t *pq) 
    {
        const size_t initial_capacity = 1000;
        pq->data      = (partition_t *)malloc(initial_capacity * sizeof(partition_t));
        pq->_capacity = initial_capacity;
        pq->count     = 0;
    }
    
    /* Release the entry storage of a queue set up by init_queue. */
    void free_queue(partition_queue_t *pq) 
    {
        partition_t *storage = pq->data;
        free(storage);
    }
    
    /*
     * Make ps a view on the points of partition part; no copy is made, and
     * ps->data aliases part->base_ptr.
     * Only data and n_points are set (the former duplicated n_points
     * assignment is removed). dims and the box buffers are left to the caller.
     */
    void get_pointset_from_partition(pointset_t *ps, partition_t *part) 
    {
        ps->data     = part->base_ptr;
        ps->n_points = part->n_points;
    }
    
    /*
     * Estimate the `fraction`-quantile of the distributed pointset ps along
     * dimension selected_dim. Steps:
     * - compute ps's global bounding box;
     * - histogram the points into n_bins bins;
     * - interpolate a first guess and measure its empirical fraction (g.ep);
     * - refine it with refine_pure_binning until it is within tolerance.
     * Collective: all ranks must call it. May reorder ps->data.
     */
    guess_t compute_median_pure_binning(global_context_t *ctx, pointset_t *ps, float_t fraction, int selected_dim, int n_bins, float_t tolerance)
    {
        uint64_t *global_bin_counts_int = (uint64_t *)malloc(n_bins * sizeof(uint64_t));

        compute_bounding_box_pointset(ctx, ps);
        compute_pure_global_binning(ctx, ps, global_bin_counts_int, n_bins, selected_dim);

        guess_t g = retrieve_guess_pure(ctx, ps, global_bin_counts_int, n_bins, selected_dim, fraction);
        g.ep = check_pc_pointset_parallel(ctx, ps, g, selected_dim, fraction);
        g = refine_pure_binning(ctx, ps, g, global_bin_counts_int, n_bins, selected_dim, fraction, tolerance);

        free(global_bin_counts_int);
        return g;
    }
    
    /*
     * Number of nodes of a binary tree built by recursively splitting n
     * leaves into halves (n/2 and n - n/2): every internal node has exactly
     * two children, so the count is 2n - 1.
     * Returns 0 for n <= 0; the former recursion never terminated there.
     */
    int compute_n_nodes(int n)
    {
        if (n <= 0) return 0;
        return 2 * n - 1;
    }
    
    /*
     * Initialise an empty top tree whose node array can hold a complete
     * binary tree with at least ctx->world_size leaves (one per rank):
     * 2^(l+1) - 1 nodes with l = ceil(log2(world_size)).
     * Every node is reset: no children/parent, owner -1, n_points 0,
     * split_dim -1, split_val 0, boxes NULL.
     */
    void top_tree_init(global_context_t *ctx, top_kdtree_t *tree) 
    {
        /*
         * Smallest l with 2^l >= world_size, in integer arithmetic. The former
         * ceil(log2((float_t)world_size)) could round down when float_t is
         * float and world_size > 2^24, and was undefined for world_size <= 0.
         */
        int l = 0;
        while ((1LL << l) < ctx->world_size) ++l;
        int tree_nodes = (1 << (l + 1)) - 1;

        tree->_nodes = (top_kdtree_node_t*)malloc(tree_nodes * sizeof(top_kdtree_node_t));
        for (int i = 0; i < tree_nodes; ++i)
        {
            tree -> _nodes[i].lch = NULL;
            tree -> _nodes[i].rch = NULL;
            tree -> _nodes[i].parent = NULL;
            tree -> _nodes[i].owner = -1;
            tree -> _nodes[i].n_points = 0;
            tree -> _nodes[i].split_dim = -1;
            tree -> _nodes[i].split_val = 0.f;
            tree -> _nodes[i].lb_node_box = NULL;
            tree -> _nodes[i].ub_node_box = NULL;
        }
        tree->_capacity = tree_nodes;
        tree->dims      = ctx->dims;
        tree->count     = 0;
    }
    
    /*
     * Free the per-node bounding boxes of the tree->count generated nodes and
     * the node array itself. The tree is then left empty (NULL array, zero
     * count/capacity), so accidental reuse cannot double free.
     */
    void top_tree_free(global_context_t *ctx, top_kdtree_t *tree) 
    {
        for (int i = 0; i < tree -> count; ++i)
        {
            /* free(NULL) is a no-op: no guard needed */
            free(tree -> _nodes[i].lb_node_box);
            free(tree -> _nodes[i].ub_node_box);
        }
        free(tree->_nodes);

        tree->_nodes    = NULL;
        tree->count     = 0;
        tree->_capacity = 0;
    }
    
    /*
     * Hand out the next unused node of tree->_nodes and increment tree->count.
     * The node gets fresh lb/ub box buffers (ctx->dims entries each, owned by
     * the tree and released by top_tree_free), no links, owner -1 and
     * split_dim 0.
     * Terminates the process if the capacity fixed by top_tree_init is
     * exhausted, instead of writing past the end of the node array.
     */
    top_kdtree_node_t* top_tree_generate_node(global_context_t* ctx, top_kdtree_t* tree)
    {
        if (tree -> count >= tree -> _capacity)
        {
            fprintf(stderr, "top_tree_generate_node: node capacity %lld exceeded\n",
                    (long long)tree -> _capacity);
            exit(1);
        }

        top_kdtree_node_t* ptr = tree -> _nodes + tree -> count;
        ptr -> lch = NULL;
        ptr -> rch = NULL;
        ptr -> parent = NULL;
        ptr -> lb_node_box = (float_t*)malloc(ctx -> dims * sizeof(float_t));
        ptr -> ub_node_box = (float_t*)malloc(ctx -> dims * sizeof(float_t));
        ptr -> owner       = -1;
        ptr -> split_dim   = 0;
        ++tree -> count;
        return ptr;
    }
    
    /*
     * Debug: print root and, recursively in pre-order (left subtree first),
     * all of its descendants through MPI_DB_PRINT.
     */
    void tree_print(global_context_t* ctx, top_kdtree_node_t* root)
    {
        /* %p requires a void * argument */
        MPI_DB_PRINT("Node %p: \n\tsplit_dim %d \n\tsplit_val %lf", (void *)root, root -> split_dim, root -> split_val);
        MPI_DB_PRINT("\n\tparent %p", (void *)root -> parent);
        MPI_DB_PRINT("\n\towner  %d", root -> owner);
        MPI_DB_PRINT("\n\tbox");
        MPI_DB_PRINT("\n\tlch %p", (void *)root -> lch);
        MPI_DB_PRINT("\n\trch %p\n", (void *)root -> rch);
        for(size_t d = 0; d < ctx -> dims; ++d) MPI_DB_PRINT("\n\t  d%d:[%lf, %lf]",(int)d, root -> lb_node_box[d], root -> ub_node_box[d]); 
        MPI_DB_PRINT("\n");
        if(root -> lch) tree_print(ctx, root -> lch);
        if(root -> rch) tree_print(ctx, root -> rch);
    }
    /*
     * Append root and, recursively in pre-order (left subtree first), all of
     * its descendants to nodes_file, one comma-separated line per node:
     * level,owner,split_dim,split_val,lb[0..dims-1],ub[0..dims-1]
     */
    void _recursive_nodes_to_file(global_context_t* ctx, FILE* nodes_file, top_kdtree_node_t* root, int level)
    {
        const float_t *lo = root -> lb_node_box;
        const float_t *hi = root -> ub_node_box;

        fprintf(nodes_file, "%d,", level);
        fprintf(nodes_file, "%d,", root -> owner);
        fprintf(nodes_file, "%d,", root -> split_dim);
        fprintf(nodes_file, "%lf,", root -> split_val);

        for (int j = 0; j < ctx -> dims; ++j) fprintf(nodes_file, "%lf,", lo[j]);
        /* the last upper bound closes the line */
        for (int j = 0; j < ctx -> dims - 1; ++j) fprintf(nodes_file, "%lf,", hi[j]);
        fprintf(nodes_file, "%lf\n", hi[ctx -> dims - 1]);

        top_kdtree_node_t *children[2] = { root -> lch, root -> rch };
        for (int c = 0; c < 2; ++c)
        {
            if (children[c]) _recursive_nodes_to_file(ctx, nodes_file, children[c], level + 1);
        }
    }
    /*
     * Dump the whole top tree, starting at tree->root, to the CSV file at
     * nodes_path (overwritten). If the file cannot be opened, a message is
     * printed and nothing is written.
     */
    void write_nodes_to_file( global_context_t* ctx,top_kdtree_t* tree, 
                            const char* nodes_path) 
    {
        FILE* out = fopen(nodes_path, "w");
        if (out != NULL) 
        {
            _recursive_nodes_to_file(ctx, out, tree -> root, 0);
            fclose(out);
        } 
        else 
        {
            printf("Cannot open hp file\n");
        }
    }
    
    /*
     * Debug: print, through MPI_DB_PRINT, every node of the subtree rooted at
     * root that has an owner (owner != -1), visiting nodes in pre-order with
     * the left subtree first.
     */
    void tree_print_leaves(global_context_t* ctx, top_kdtree_node_t* root)
    {
        if(root -> owner != -1)
        {
            /* %p requires a void * argument */
            MPI_DB_PRINT("Node %p: \n\tsplit_dim %d \n\tsplit_val %lf", (void *)root, root -> split_dim, root -> split_val);
            MPI_DB_PRINT("\n\tparent %p", (void *)root -> parent);
            MPI_DB_PRINT("\n\towner  %d", root -> owner);
            MPI_DB_PRINT("\n\tbox");
            MPI_DB_PRINT("\n\tlch %p", (void *)root -> lch);
            MPI_DB_PRINT("\n\trch %p\n", (void *)root -> rch);
            for(size_t d = 0; d < ctx -> dims; ++d) MPI_DB_PRINT("\n\t  d%d:[%lf, %lf]",(int)d, root -> lb_node_box[d], root -> ub_node_box[d]); 
            MPI_DB_PRINT("\n");
        }
        if(root -> lch) tree_print_leaves(ctx, root -> lch);
        if(root -> rch) tree_print_leaves(ctx, root -> rch);
    }
    
    void build_top_kdtree(global_context_t *ctx, pointset_t *og_pointset, top_kdtree_t *tree, int n_bins, float_t tolerance) 
    {
        size_t tot_n_points = 0;
        MPI_Allreduce(&(og_pointset->n_points), &tot_n_points, 1, MPI_UINT64_T, MPI_SUM, ctx->mpi_communicator);
    
        /*
        MPI_DB_PRINT("[MASTER] Top tree builder invoked\n");
        */
        MPI_DB_PRINT("\n");
        MPI_DB_PRINT("Building top tree on %lu points with %d processors\n", tot_n_points, ctx->world_size);
        MPI_DB_PRINT("\n");
    
        size_t current_partition_n_points = tot_n_points;
        size_t expected_points_per_node = tot_n_points / ctx->world_size;
    
        /* enqueue the two partitions */
    
        compute_bounding_box_pointset(ctx, og_pointset);
    
        partition_queue_t queue;
        init_queue(&queue);
    
        int selected_dim = 0;
        partition_t current_partition = {  .d          = selected_dim,
                                           .base_ptr = og_pointset->data,
                                           .n_points = og_pointset->n_points,
                                           .n_procs = ctx->world_size,
                                           .parent  = NULL,
                                           .lr         = NO_CHILD 
                                        };
    
        enqueue_partition(&queue, current_partition);
        pointset_t current_pointset;
        current_pointset.lb_box = (float_t*)malloc(ctx -> dims * sizeof(float_t));
        current_pointset.ub_box = (float_t*)malloc(ctx -> dims * sizeof(float_t));
    
        while (queue.count) 
        {
            /*dequeue the partition to process */
            current_partition = dequeue_partition(&queue);
    
            /* generate e pointset for that partition */
    
            get_pointset_from_partition(&current_pointset, &current_partition);
            current_pointset.dims = ctx->dims;
    
            /*generate a tree node */
    
            top_kdtree_node_t* current_node  = top_tree_generate_node(ctx, tree);
            /* insert node */
            
            /*
            MPI_DB_PRINT("Handling partition: \n\tcurrent_node %p, \n\tdim %d, \n\tn_points %d, \n\tstart_proc %d, \n\tn_procs %d, \n\tparent %p\n", 
                    current_node,
                    current_partition.d,
                    current_partition.n_points,
                    current_partition.start_proc,
                    current_partition.n_procs,
                    current_partition.parent);
            */
    
            switch (current_partition.lr) {
                case TOP_TREE_LCH:
                    if(current_partition.parent)
                    {
                        current_node -> parent           = current_partition.parent;
                        current_node -> parent -> lch = current_node;
                        /* compute the box */
                        /*
                         * left child has lb equal to parent
                         * ub equal to parent except for the dim of splitting 
                         */
                        int parent_split_dim = current_node -> parent -> split_dim;
                        float_t parent_hp      = current_node -> parent -> split_val;
    
                        memcpy(current_node -> lb_node_box, current_node -> parent -> lb_node_box, ctx -> dims * sizeof(float_t));
                        memcpy(current_node -> ub_node_box, current_node -> parent -> ub_node_box, ctx -> dims * sizeof(float_t));
    
                        current_node -> ub_node_box[parent_split_dim] = parent_hp;
                    }
                    break;
    
                case TOP_TREE_RCH:
                    if(current_partition.parent)
                    {
                        current_node -> parent           = current_partition.parent;
                        current_node -> parent -> rch = current_node;
    
                        int parent_split_dim = current_node -> parent -> split_dim;
                        float_t parent_hp      = current_node -> parent -> split_val;
    
                        /*
                         * right child has ub equal to parent
                         * lb equal to parent except for the dim of splitting 
                         */
    
                        memcpy(current_node -> lb_node_box, current_node -> parent -> lb_node_box, ctx -> dims * sizeof(float_t));
                        memcpy(current_node -> ub_node_box, current_node -> parent -> ub_node_box, ctx -> dims * sizeof(float_t));
    
                        current_node -> lb_node_box[parent_split_dim] = parent_hp;
                    }
                    break;
                case NO_CHILD:
                    {
                        tree -> root = current_node;
                        memcpy(current_node -> lb_node_box, og_pointset -> lb_box, ctx -> dims * sizeof(float_t));
                        memcpy(current_node -> ub_node_box, og_pointset -> ub_box, ctx -> dims * sizeof(float_t));
                    }
                    break;
            }
    
            current_node -> split_dim = current_partition.d;
            current_node -> parent = current_partition.parent;
            current_node -> lch = NULL;
            current_node -> rch = NULL;
    
    
            /* handle partition */
            if(current_partition.n_procs > 1)
            {
                float_t fraction = (current_partition.n_procs / 2) / (float_t)current_partition.n_procs;
                guess_t g = compute_median_pure_binning(ctx, &current_pointset, fraction, current_partition.d, n_bins, tolerance);
                int pv = partition_data_around_value(current_pointset.data, ctx->dims, current_partition.d, 0, current_pointset.n_points, g.x_guess);
    
                current_node -> split_val = g.x_guess;
    
                size_t points_left = (size_t)pv;
                size_t points_right = current_partition.n_points - points_left;
    
                int procs_left = current_partition.n_procs * fraction;
                int procs_right = current_partition.n_procs - procs_left;
    
    
                /*
                MPI_DB_PRINT("Chosing as guess: %lf, seareching for %lf, obtained %lf\n", g.x_guess, fraction, g.ep);
                MPI_DB_PRINT("-------------------\n\n");
                */
        
    
    
                int next_dimension = (++selected_dim) % (ctx->dims);
                partition_t left_partition = {
                    .n_points     = points_left, 
                    .n_procs     = procs_left,
                    .start_proc = current_partition.start_proc,
                    .parent     = current_node,
                    .lr         = TOP_TREE_LCH,
                    .base_ptr     = current_pointset.data,
                    .d             = next_dimension,
                };
    
                partition_t right_partition = {
                    .n_points     = points_right, 
                    .n_procs     = procs_right,
                    .start_proc = current_partition.start_proc + procs_left,
                    .parent     = current_node,
                    .lr         = TOP_TREE_RCH,
                    .base_ptr     = current_pointset.data + pv * current_pointset.dims,
                    .d             = next_dimension
                };
    
                enqueue_partition(&queue, left_partition);
                enqueue_partition(&queue, right_partition);
            }
            else
            {
                current_node -> owner = current_partition.start_proc;
            }
        }
        tree -> root = tree -> _nodes;
    
        #if defined(WRITE_TOP_NODES)
        MPI_DB_PRINT("Root is %p\n", tree -> root);
            if(I_AM_MASTER)
            {
                //tree_print(ctx, tree -> root);
                write_nodes_to_file(ctx, tree, "bb/top_nodes.csv");
            }
        #endif
    
        
        free(current_pointset.lb_box);
        free(current_pointset.ub_box);
        free_queue(&queue);
    
    }
    
    /*
     * Return the rank that owns the top-tree leaf whose region contains `data`.
     *
     * The walk starts at tree -> root and goes down one level per iteration.
     * A point whose coordinate along the node's split dimension is strictly
     * greater than the split value moves to the right child; otherwise it
     * moves to the left child. The walk stops at the first node whose owner
     * is not -1, and that owner is returned.
     *
     * ctx is not used by the walk.
     */
    int compute_point_owner(global_context_t* ctx, top_kdtree_t* tree, float_t* data)
    {
        top_kdtree_node_t* node = tree -> root;

        while(node -> owner == -1)
        {
            int goes_right = data[node -> split_dim] > node -> split_val;
            node = goes_right ? node -> rch : node -> lch;
        }

        return node -> owner;
    }
    
    /* to partition points around owners */
    /*
     * Reorder the entries in [left, right) so that every entry whose key is
     * smaller than ref_key comes before all the others.
     *
     * key     : one int per entry, permuted together with the entry data
     * val     : entry data, vec_len float_t per entry, exchanged through
     *           swap_data_element
     * vec_len : number of float_t per entry
     *
     * Returns the position of the first entry whose key is >= ref_key, i.e.
     * left + (number of entries in [left, right) with key < ref_key).
     * Entries with key < ref_key keep their relative order; the remaining
     * entries may be reordered.
     */
    int partition_data_around_key(int* key, float_t *val, int vec_len, int ref_key , int left, int right) 
    {
        int boundary = left;

        for(int j = left; j < right; ++j)
        {
            if(key[j] >= ref_key) continue;

            /* append entry j to the end of the "smaller than ref_key" region */
            swap_data_element(val + boundary * vec_len, val + j * vec_len, vec_len);

            int saved_key = key[boundary];
            key[boundary] = key[j];
            key[j]        = saved_key;

            ++boundary;
        }

        return boundary;
    }
    
    
    
    /*
     * Redistribute the local points so that each rank ends up holding exactly
     * the points that fall inside its top-tree leaf.
     *
     *   1. compute the owner rank of every local point (compute_point_owner);
     *   2. reorder ctx -> local_data in place so that points are grouped by
     *      owner, in increasing rank order;
     *   3. exchange counts (MPI_Alltoall), then the points (MPI_Alltoallv);
     *   4. replace ctx -> local_data with the received buffer and refresh
     *      ctx -> local_n_points, ctx -> idx_start, ctx -> rank_n_points and
     *      ctx -> rank_idx_start.
     *
     * A final debug pass re-checks that every received point is owned by this
     * rank and prints a message for each one that is not.
     *
     * MPI counts and displacements are int, so the number of float_t values
     * sent or received by a rank must fit in an int.
     * Collective over ctx -> mpi_communicator.
     */
    void exchange_points(global_context_t* ctx, top_kdtree_t* tree)
    {
        int* points_per_proc = (int*)malloc(ctx -> world_size * sizeof(int));
        /* one owner rank per local point */
        int* points_owners   = (int*)malloc(ctx -> local_n_points * sizeof(int));

        /* compute the owner of every local point by walking the top tree */
        #pragma omp parallel for
        for(size_t i = 0; i < ctx -> local_n_points; ++i)
        {
            points_owners[i] = compute_point_owner(ctx, tree, ctx -> local_data + (i * ctx -> dims));
        }

        /*
         * Group points by owner with one partition pass per rank boundary.
         * After the pass for `owner`, entries [0, last_idx) are exactly those
         * whose owner is < `owner`, so points_per_proc[owner - 1] temporarily
         * holds a cumulative count.
         */
        int last_idx = 0;
        for(int owner = 1; owner < ctx -> world_size; ++owner)
        {
            last_idx = partition_data_around_key(points_owners, ctx -> local_data, ctx -> dims, owner, last_idx, ctx -> local_n_points);
            points_per_proc[owner - 1] = last_idx;
        }
        points_per_proc[ctx -> world_size - 1] = ctx -> local_n_points;

        /* cumulative counts -> number of points to send to each rank */
        for(int i = ctx -> world_size - 1; i > 0; --i)
        {
            points_per_proc[i] = points_per_proc[i] - points_per_proc[i - 1];
        }

        int* rcv_count   = (int*)malloc(ctx -> world_size * sizeof(int));
        int* rcv_displs  = (int*)malloc(ctx -> world_size * sizeof(int));
        int* send_displs = (int*)malloc(ctx -> world_size * sizeof(int));
        int* send_count  = points_per_proc;

        MPI_Barrier(ctx -> mpi_communicator);

        /* every rank learns how many points it will receive from every other rank */
        MPI_Alltoall(send_count, 1, MPI_INT, rcv_count, 1, MPI_INT, ctx -> mpi_communicator);

        rcv_displs[0]  = 0;
        send_displs[0] = 0;
        for(int i = 1; i < ctx -> world_size; ++i) 
        {
            rcv_displs[i]  = rcv_displs[i - 1]  + rcv_count[i - 1];
            send_displs[i] = send_displs[i - 1] + send_count[i - 1];
        }

        /* convert counts/displacements from points to float_t elements */
        int tot_count = 0;
        for(int i = 0; i < ctx -> world_size; ++i) 
        {
            send_displs[i] = send_displs[i] * ctx -> dims;
            send_count[i]  = send_count[i]  * ctx -> dims;

            rcv_displs[i]  = rcv_displs[i]  * ctx -> dims;
            rcv_count[i]   = rcv_count[i]   * ctx -> dims;
            tot_count     += rcv_count[i];
        }

        float_t* rcvbuffer = (float_t*)malloc((size_t)tot_count * sizeof(float_t));
        if(tot_count > 0 && rcvbuffer == NULL)
        {
            DB_PRINT("Rank %d failed to allocate %d values to receive points\n", ctx -> mpi_rank, tot_count);
            exit(1);
        }

        /* exchange points */
        MPI_Alltoallv(  ctx -> local_data, send_count, send_displs, MPI_MY_FLOAT, 
                        rcvbuffer, rcv_count, rcv_displs, MPI_MY_FLOAT, 
                        ctx -> mpi_communicator);

        ctx -> local_n_points = tot_count / ctx -> dims; 

        /*
         * Gather through an int local so that the MPI_INT datatype matches the
         * buffer exactly, whatever the declared type of ctx -> local_n_points.
         */
        int my_n_points = (int)ctx -> local_n_points;
        int* ppp = (int*)malloc(ctx -> world_size * sizeof(int));
        MPI_Allgather(&my_n_points, 1, MPI_INT, ppp, 1, MPI_INT, ctx -> mpi_communicator);

        /* global index of this rank's first point */
        ctx -> idx_start = 0;
        for(int i = 0; i < ctx -> mpi_rank; ++i)
        {
            ctx -> idx_start += ppp[i];
        }

        /* per-rank slices of the global index space */
        for(int i = 0; i < ctx -> world_size; ++i) ctx -> rank_n_points[i] = ppp[i]; 

        ctx -> rank_idx_start[0] = 0;
        for(int i = 1; i < ctx -> world_size; ++i) ctx -> rank_idx_start[i] = ppp[i - 1] + ctx -> rank_idx_start[i - 1]; 

        free(ppp);

        /* the received points replace the old local data */
        free(ctx -> local_data);
        ctx -> local_data = rcvbuffer;

        /* sanity check: every point now held must be owned by this rank */
        for(size_t i = 0; i < ctx -> local_n_points; ++i)
        {
            int o = compute_point_owner(ctx, tree, ctx -> local_data + (i * ctx -> dims));
            if(o != ctx -> mpi_rank) DB_PRINT("rank %d got an error\n",ctx -> mpi_rank);
        }

        free(points_owners);
        free(points_per_proc);
        free(rcv_count);
        free(rcv_displs);
        free(send_displs);
    }
    
    /*
     * Map a rank-local point index to the global numbering: the global index
     * is this rank's first global index (ctx -> idx_start) plus local_idx.
     */
    static inline size_t local_to_global_idx(global_context_t* ctx, size_t local_idx)
    {
        size_t global_idx = ctx -> idx_start;
        global_idx += local_idx;
        return global_idx;
    }
    
    /*
     * Rewrite in place the array_idx of the first ctx -> local_n_points nodes
     * of local_tree, converting each from rank-local to global numbering
     * through local_to_global_idx.
     */
    void translate_tree_idx_to_global(global_context_t* ctx, kdtree_v2* local_tree) 
    {
        size_t n_nodes = ctx -> local_n_points;
        size_t node    = 0;

        while(node < n_nodes)
        {
            size_t global_idx = local_to_global_idx(ctx, local_tree -> _nodes[node].array_idx);
            local_tree -> _nodes[node].array_idx = global_idx;
            ++node;
        }
    }
    
    /*
     * k-nearest-neighbour query for `data` on this rank's local kd-tree.
     *
     * Before the query, ctx -> dims is stored in the file-scope `data_dims`.
     * Presumably the kdtreeV2 routines read the dimensionality from that
     * global; confirm this in kdtreeV2. The heap returned by knn_kdtree_v2 is
     * handed back to the caller as is.
     */
    heap ngbh_search_kdtree(global_context_t* ctx, kdtree_v2* local_tree, float_t* data, int k)
    {
        data_dims = ctx -> dims;

        heap found = knn_kdtree_v2(data, local_tree -> root, k);
        return found;
    }
    
    /*
     * Legacy single-pass top-tree walk. The neighbour search in this file now
     * uses tree_walk_v2_find_n_points followed by tree_walk_v2_append_points.
     *
     * The walk visits every top-tree leaf whose region may intersect the
     * search ball around `point`. max_dist is compared with squared distances
     * from the splitting hyperplanes, so it is a squared radius.
     *
     * For each visited leaf owned by another rank:
     *   - a copy of `point` (ctx -> dims values) is appended to
     *     data_to_send_per_proc[owner];
     *   - point_idx is appended to local_idx_of_the_point[owner].
     *
     * point_to_send_count[owner] is the fill level and
     * point_to_send_capacity[owner] the allocated size, both measured in
     * points. Buffers are grown with realloc when full. Appends are
     * serialised by an OpenMP critical section. Leaves owned by this rank
     * are skipped.
     *
     * The process exits if a buffer cannot be grown.
     */
    void tree_walk(
            global_context_t* ctx, 
            top_kdtree_node_t* root, 
            int point_idx,
            float_t max_dist,
            float_t* point,
            float_t** data_to_send_per_proc, 
            int** local_idx_of_the_point, 
            int* point_to_send_count, 
            int* point_to_send_capacity)
    {
        if(root -> owner != -1)
        {
            /* leaf: only points headed to another rank are queued */
            if(root -> owner == ctx -> mpi_rank) return;

            #pragma omp critical
            {
                int owner    = root -> owner;
                int idx      = point_to_send_count[owner];
                int capacity = point_to_send_capacity[owner];
                size_t len   = (size_t)ctx -> dims;

                if(idx == capacity)
                {
                    /*
                     * Grow by ~10% and always by at least one slot. The former
                     * `capacity * 1.1`, truncated to int, did not grow at all
                     * for capacities below 10, so the write below overflowed.
                     */
                    int new_capacity = capacity + capacity / 10 + 1;

                    float_t* grown_data = (float_t*)realloc(data_to_send_per_proc[owner], (size_t)new_capacity * len * sizeof(float_t));
                    if(grown_data == NULL)
                    {
                        DB_PRINT("Rank %d failed to grow the send buffer for rank %d\n", ctx -> mpi_rank, owner);
                        exit(1);
                    }
                    data_to_send_per_proc[owner] = grown_data;

                    int* grown_idx = (int*)realloc(local_idx_of_the_point[owner], (size_t)new_capacity * sizeof(int));
                    if(grown_idx == NULL)
                    {
                        DB_PRINT("Rank %d failed to grow the index buffer for rank %d\n", ctx -> mpi_rank, owner);
                        exit(1);
                    }
                    local_idx_of_the_point[owner] = grown_idx;

                    point_to_send_capacity[owner] = new_capacity;
                }

                memcpy(data_to_send_per_proc[owner] + len * (size_t)idx, point, len * sizeof(float_t));
                local_idx_of_the_point[owner][idx] = point_idx;

                point_to_send_count[owner]++;
            }
            return;
        }

        /* internal node: signed distance of the point from the splitting plane */
        float_t hp_distance = point[root -> split_dim] - root -> split_val;
        __builtin_prefetch(root -> lch, 0, 3);
        __builtin_prefetch(root -> rch, 0, 3);

        int side = hp_distance > 0.f;
        top_kdtree_node_t* near_child = (side == TOP_TREE_RCH) ? root -> rch : root -> lch;
        top_kdtree_node_t* far_child  = (side == TOP_TREE_RCH) ? root -> lch : root -> rch;

        /* always descend on the point's own side first */
        if(near_child)
        {
            tree_walk(ctx, near_child, point_idx, max_dist, point, 
                    data_to_send_per_proc, local_idx_of_the_point, 
                    point_to_send_count, point_to_send_capacity);
        }

        /* the other side can hold neighbours only if the ball crosses the plane */
        if(far_child && max_dist > (hp_distance * hp_distance))
        {
            tree_walk(ctx, far_child, point_idx, max_dist, point, 
                    data_to_send_per_proc, local_idx_of_the_point, 
                    point_to_send_count, point_to_send_capacity);
        }
    }
    
    /*
     * First pass of the two-pass top-tree walk: count, for each destination
     * rank, how many copies of `point` it must receive.
     *
     * The walk visits every top-tree leaf whose region may intersect the
     * search ball around `point`. max_dist is compared with squared distances
     * from the splitting hyperplanes, so it is a squared radius.
     *
     * For each visited leaf owned by another rank,
     * point_to_send_capacity[owner] is incremented with an OpenMP atomic, so
     * the function can be called from a parallel loop. Leaves owned by this
     * rank are ignored. point_idx is not used by this pass.
     */
    void tree_walk_v2_find_n_points(
            global_context_t* ctx, 
            top_kdtree_node_t* root, 
            int point_idx,
            float_t max_dist,
            float_t* point,
            int* point_to_send_capacity) 
    {
        if(root -> owner != -1)
        {
            /* leaf: count it only when it belongs to another rank */
            if(root -> owner != ctx -> mpi_rank)
            {
                #pragma omp atomic update 
                point_to_send_capacity[root -> owner]++;
            }
            return;
        }

        /* internal node: signed distance of the point from the splitting plane */
        float_t hp_distance = point[root -> split_dim] - root -> split_val;
        __builtin_prefetch(root -> lch, 0, 3);
        __builtin_prefetch(root -> rch, 0, 3);

        int side = hp_distance > 0.f;
        top_kdtree_node_t* near_child = (side == TOP_TREE_RCH) ? root -> rch : root -> lch;
        top_kdtree_node_t* far_child  = (side == TOP_TREE_RCH) ? root -> lch : root -> rch;

        /* always descend on the point's own side first */
        if(near_child)
        {
            tree_walk_v2_find_n_points(ctx, near_child, point_idx, max_dist, point, point_to_send_capacity);
        }

        /* the other side can hold neighbours only if the ball crosses the plane */
        if(far_child && max_dist > (hp_distance * hp_distance))
        {
            tree_walk_v2_find_n_points(ctx, far_child, point_idx, max_dist, point, point_to_send_capacity);
        }
    }
    
    /*
     * Second pass of the two-pass top-tree walk: append `point` to the send
     * buffer of every other rank whose leaf may intersect the search ball.
     *
     * The leaves visited are the same as in tree_walk_v2_find_n_points with
     * the same max_dist (a squared radius). For each leaf owned by another
     * rank, a slot is reserved with an OpenMP atomic capture on
     * point_to_send_count[owner]. Then:
     *   - ctx -> dims values of `point` are copied into
     *     data_to_send_per_proc[owner];
     *   - point_idx is stored in local_idx_of_the_point[owner].
     *
     * Writes are not bounds-checked. Both buffers of each owner must already
     * have room for the final value of point_to_send_count[owner] entries.
     * In this file they are sized from a preceding
     * tree_walk_v2_find_n_points pass.
     */
    void tree_walk_v2_append_points(
            global_context_t* ctx, 
            top_kdtree_node_t* root, 
            int point_idx,
            float_t max_dist,
            float_t* point,
            float_t** data_to_send_per_proc, 
            int** local_idx_of_the_point, 
            int* point_to_send_count) 
    {
        if(root -> owner != -1)
        {
            /* leaf: queue the point only when it belongs to another rank */
            if(root -> owner != ctx -> mpi_rank)
            {
                int owner = root -> owner;
                int idx;

                /* each thread obtains a distinct slot */
                #pragma omp atomic capture
                idx = point_to_send_count[owner]++;

                size_t len    = (size_t)ctx -> dims;
                float_t* base = data_to_send_per_proc[owner] + len * (size_t)idx; 

                memcpy(base, point, len * sizeof(float_t));
                local_idx_of_the_point[owner][idx] = point_idx;
            }
            return;
        }

        /* internal node: signed distance of the point from the splitting plane */
        float_t hp_distance = point[root -> split_dim] - root -> split_val;
        __builtin_prefetch(root -> lch, 0, 3);
        __builtin_prefetch(root -> rch, 0, 3);

        int side = hp_distance > 0.f;
        top_kdtree_node_t* near_child = (side == TOP_TREE_RCH) ? root -> rch : root -> lch;
        top_kdtree_node_t* far_child  = (side == TOP_TREE_RCH) ? root -> lch : root -> rch;

        /* always descend on the point's own side first */
        if(near_child)
        {
            tree_walk_v2_append_points(ctx, near_child, point_idx, max_dist, point, 
                    data_to_send_per_proc, local_idx_of_the_point, point_to_send_count);
        }

        /* the other side can hold neighbours only if the ball crosses the plane */
        if(far_child && max_dist > (hp_distance * hp_distance))
        {
            tree_walk_v2_append_points(ctx, far_child, point_idx, max_dist, point, 
                    data_to_send_per_proc, local_idx_of_the_point, point_to_send_count);
        }
    }
    
    
    /*
     * Convert, in place, the array_idx of the first H -> count entries of the
     * heap from rank-local to global numbering (local_to_global_idx).
     */
    void convert_heap_idx_to_global(global_context_t* ctx, heap* H)
    {
        uint64_t j = 0;
        while(j < H -> count)
        {
            heap_node* entry = &(H -> data[j]);
            entry -> array_idx = local_to_global_idx(ctx, entry -> array_idx);
            ++j;
        }
    }
    
    /*
     * Print, through MPI_DB_PRINT, a rough estimate of the per-node memory
     * needed by the neighbourhood search. Warn when it exceeds half of the
     * free RAM reported by sysinfo(2).
     *
     * For every rank sharing the node, the estimate counts:
     *   - the local coordinates;
     *   - one datapoint_info_t per point;
     *   - k heap_node per point.
     * Communication buffers are not included.
     *
     * Collective over ctx -> mpi_communicator (communicator split + barrier).
     */
    void print_diagnositcs(global_context_t* ctx, int k)
    {
        /* number of ranks of this communicator sharing the node's memory */
        MPI_Comm shmcomm;
        MPI_Comm_split_type(ctx -> mpi_communicator, MPI_COMM_TYPE_SHARED, 0,
                            MPI_INFO_NULL, &shmcomm);
        int shm_world_size;
        MPI_Comm_size(shmcomm, &shm_world_size);
        MPI_Comm_free(&shmcomm);

        MPI_DB_PRINT("\n");
        MPI_DB_PRINT("[INFO] Got %d ranks per node \n",shm_world_size); 

        /* computed in double so float32 builds do not lose precision */
        double n_points   = (double)ctx -> local_n_points;
        /* data */
        double memory_use = n_points * (double)ctx -> dims * (double)sizeof(float_t);
        memory_use += (double)sizeof(datapoint_info_t) * n_points; 
        /* ngbh */
        memory_use += (double)sizeof(heap_node) * (double)k * n_points; 
        memory_use = memory_use / 1e9 * shm_world_size;

        /* size_t arithmetic: avoids int overflow and a %d/type mismatch */
        size_t points_per_node = (size_t)ctx -> local_n_points * (size_t)shm_world_size;
        MPI_DB_PRINT("       Got ~%zu points per node and %d ngbh per points\n", points_per_node, k); 
        MPI_DB_PRINT("       Expected to use ~%.2lfGB of memory for each node, plus memory required to communicate ngbh\n", memory_use); 

        struct sysinfo info;
        if(sysinfo(&info) == 0)
        {
            /* freeram is expressed in units of mem_unit bytes */
            double free_gb = (double)info.freeram * (double)info.mem_unit / 1e9;
            if(memory_use > 0.5 * free_gb)
                MPI_DB_PRINT("/!\\    Projected memory usage is more than half of the node memory, may go into troubles while communicating ngbh\n"); 
        }
        MPI_DB_PRINT("\n");

        MPI_Barrier(ctx -> mpi_communicator);
    }
    
    
    void mpi_ngbh_search(global_context_t* ctx, datapoint_info_t* dp_info, top_kdtree_t* top_tree, kdtree_v2* local_tree, float_t* data, int k)
    {
        /* local search */
        /* print diagnostics */
        print_diagnositcs(ctx, k);
        
        TIME_DEF;
        double elapsed_time;
    
        TIME_START;
        MPI_Barrier(ctx -> mpi_communicator);
        #pragma omp parallel for
        for(int p = 0; p < ctx -> local_n_points; ++p)
        {
            idx_t idx = local_tree -> _nodes[p].array_idx;
            /* actually we want to preserve the heap to then insert guesses from other nodes */
            dp_info[idx].ngbh = knn_kdtree_v2_no_heapsort(local_tree -> _nodes[p].data, local_tree -> root, k);
            convert_heap_idx_to_global(ctx, &(dp_info[idx].ngbh));
            dp_info[idx].cluster_idx = -1;
            dp_info[idx].is_center = 0;
            dp_info[idx].array_idx = idx + ctx -> idx_start;
        }
        elapsed_time = TIME_STOP;
        LOG_WRITE("Local neighborhood search", elapsed_time);
    
    
        TIME_START;
        /* find if a points needs a refine on the global tree */
        float_t** data_to_send_per_proc      = (float_t**)malloc(ctx -> world_size * sizeof(float_t*));
        int**       local_idx_of_the_point     = (int**)malloc(ctx -> world_size * sizeof(int*));
        int*       point_to_snd_count       = (int*)malloc(ctx -> world_size * sizeof(int));
        int*       point_to_snd_capacity    = (int*)malloc(ctx -> world_size * sizeof(int));
    
        for(int i = 0; i < ctx -> world_size; ++i)
        {
            /* allocate it afterwards */
    
            /* OLD VERSION 
            data_to_send_per_proc[i]  = (float_t*)malloc(100 * (ctx -> dims) * sizeof(float_t));    
            local_idx_of_the_point[i] = (int*)malloc(100 * sizeof(int));    
            point_to_snd_capacity[i] = 100;
            */
    
            /* NEW VERSION with double tree walk */
            point_to_snd_capacity[i] = 0;
            point_to_snd_count[i]    = 0;
        }
    
        /* for each point walk the tree and find to which proc send data */
        /* actually compute intersection of ngbh radius of each point to node box */
    
        /* OLD VERSION SINGLE TREE WALK */
        /*
        #pragma omp parallel for
        for(int i = 0; i < ctx -> local_n_points; ++i)
        {
            float_t max_dist = dp_info[i].ngbh.data[0].value;
            float_t* point   = ctx -> local_data + (i * ctx -> dims);
    
            tree_walk(ctx, top_tree -> root, i, max_dist, 
                      point, data_to_send_per_proc, local_idx_of_the_point, 
                      point_to_snd_count, point_to_snd_capacity);
        }
        */
    
        /* NEW VERSION double tree walk */
        #pragma omp parallel for
        for(int i = 0; i < ctx -> local_n_points; ++i)
        {
            float_t max_dist = dp_info[i].ngbh.data[0].value;
            float_t* point   = ctx -> local_data + (i * ctx -> dims);
            
            tree_walk_v2_find_n_points(ctx, top_tree -> root, i, max_dist, point, point_to_snd_capacity);
    
        }
    
        /* allocate needed space */
        for(int i = 0; i < ctx -> world_size; ++i)
        {
            int np = point_to_snd_capacity[i];
            data_to_send_per_proc[i]  = (float_t*)malloc(np * (ctx -> dims) * sizeof(float_t));    
            local_idx_of_the_point[i] = (int*)malloc(np * sizeof(int));    
    
        }
    
    
        #pragma omp parallel for
        for(int i = 0; i < ctx -> local_n_points; ++i)
        {
            float_t max_dist = dp_info[i].ngbh.data[0].value;
            float_t* point   = ctx -> local_data + (i * ctx -> dims);
    
            tree_walk_v2_append_points(ctx, top_tree -> root, i, max_dist, point, data_to_send_per_proc, local_idx_of_the_point, point_to_snd_count);
        }
    
    
        elapsed_time = TIME_STOP;
        LOG_WRITE("Finding points to refine", elapsed_time);
    
        TIME_START;
        int* point_to_rcv_count = (int*)malloc(ctx -> world_size * sizeof(int));
    
        /* exchange points to work on*/
        MPI_Alltoall(point_to_snd_count, 1, MPI_INT, point_to_rcv_count, 1, MPI_INT, ctx -> mpi_communicator);
    
        int* rcv_count = (int*)malloc(ctx -> world_size * sizeof(int));
        int* snd_count = (int*)malloc(ctx -> world_size * sizeof(int));
        int* rcv_displ = (int*)malloc(ctx -> world_size * sizeof(int));
        int* snd_displ = (int*)malloc(ctx -> world_size * sizeof(int));
    
        /*compute counts and displs*/
        rcv_displ[0] = 0;
        snd_displ[0] = 0;
    
    
        rcv_count[0] = point_to_rcv_count[0] * (ctx -> dims);
        snd_count[0] = point_to_snd_count[0] * (ctx -> dims);
    
        int tot_points_rcv = point_to_rcv_count[0];
        int tot_points_snd = point_to_snd_count[0];
        int tot_count = rcv_count[0];
    
        for(int i = 1; i < ctx -> world_size; ++i)
        {
            rcv_count[i] = point_to_rcv_count[i] * (ctx -> dims);        
            snd_count[i] = point_to_snd_count[i] * (ctx -> dims);        
    
            tot_count += rcv_count[i];
            tot_points_rcv += point_to_rcv_count[i];
            tot_points_snd += point_to_snd_count[i];
    
            rcv_displ[i] = rcv_displ[i - 1] + rcv_count[i - 1];
            snd_displ[i] = snd_displ[i - 1] + snd_count[i - 1];
        }
    
        float_t* __rcv_points = (float_t*)malloc(tot_points_rcv * (ctx -> dims) * sizeof(float_t));
        float_t* __snd_points = (float_t*)malloc(tot_points_snd * (ctx -> dims) * sizeof(float_t)); 
    
    
    
        /* copy data to send in contiguous memory */
        for(int i = 0; i < ctx -> world_size; ++i)
        {
            memcpy(__snd_points + snd_displ[i], data_to_send_per_proc[i], snd_count[i] * sizeof(float_t));
        }
    
    
        MPI_Alltoallv(__snd_points, snd_count, snd_displ, MPI_MY_FLOAT, 
                      __rcv_points, rcv_count, rcv_displ, MPI_MY_FLOAT, ctx -> mpi_communicator); 
    
        float_t** rcv_work_batches = (float_t**)malloc(ctx -> world_size * sizeof(float_t*));
        for(int i = 0; i < ctx -> world_size; ++i) 
        {
            //rcv_work_batches[i]       = NULL;
            rcv_work_batches[i]       = __rcv_points + rcv_displ[i];
        }
    
        MPI_Status status;
        int flag;
        /* prepare heap batches */
    
        //int work_batch_stride = 1 + ctx -> dims;
        int work_batch_stride = ctx -> dims;
    
        /* Note that I then have to recieve an equal number of heaps as the one I sent out before */
        heap_node* __heap_batches_to_snd = (heap_node*)malloc((uint64_t)k * (uint64_t)tot_points_rcv * sizeof(heap_node));
        heap_node* __heap_batches_to_rcv = (heap_node*)malloc((uint64_t)k * (uint64_t)tot_points_snd * sizeof(heap_node));
    
        
        if( __heap_batches_to_rcv == NULL)
        {
            DB_PRINT("Rank %d failed to allocate rcv_heaps %luB required\n",ctx -> mpi_rank, (uint64_t)k * (uint64_t)tot_points_rcv * sizeof(heap_node));
            exit(1);
        }
    
        if( __heap_batches_to_snd == NULL)
        {
            DB_PRINT("Rank %d failed to allocate snd_heaps %luB required\n",ctx -> mpi_rank, (uint64_t)k * (uint64_t)tot_points_snd * sizeof(heap_node));
            exit(1);
        }
    
        MPI_Barrier(ctx -> mpi_communicator);
    
        rcv_displ[0] = 0;
        snd_displ[0] = 0;
        rcv_count[0] = point_to_rcv_count[0];
        snd_count[0] = point_to_snd_count[0]; 
    
    
        for(int i = 1; i < ctx -> world_size; ++i)
        {
    
            rcv_count[i] = point_to_rcv_count[i]; 
            snd_count[i] = point_to_snd_count[i]; 
    
            rcv_displ[i] = rcv_displ[i - 1] + rcv_count[i - 1];
            snd_displ[i] = snd_displ[i - 1] + snd_count[i - 1];
        }
    
    
        heap_node** heap_batches_per_node = (heap_node**)malloc(ctx -> world_size * sizeof(heap_node*));
        for(int p = 0; p < ctx -> world_size; ++p) 
        {
            heap_batches_per_node[p] = __heap_batches_to_snd + (uint64_t)rcv_displ[p] * (uint64_t)k;
        }
    
        /* compute everything */
        elapsed_time = TIME_STOP;
        LOG_WRITE("Exchanging points", elapsed_time);
        MPI_Barrier(ctx -> mpi_communicator);
    
    
        TIME_START;
    
        /* ngbh search on recieved points */
        for(int p = 0; p < ctx -> world_size; ++p)
        {
            if(point_to_rcv_count[p] > 0 && p != ctx -> mpi_rank)
            //if(count_rcv_work_batches[p] > 0)
            {
                //heap_batches_per_node[p] = (heap_node*)malloc(k * point_to_rcv_count[p] * sizeof(heap_node));
                #pragma omp parallel for
                for(int batch = 0; batch < point_to_rcv_count[p]; ++batch)
                {
                    heap H;
                    H.count = 0;
                    H.N = k;
                    H.data = heap_batches_per_node[p] + (uint64_t)k * (uint64_t)batch; 
                    init_heap(&H);
                    //float_t* point = rcv_work_batches[p] + batch * work_batch_stride + 1; 
                    float_t* point = rcv_work_batches[p] + (uint64_t)batch * (uint64_t)work_batch_stride; 
                    knn_sub_tree_search_kdtree_v2(point, local_tree -> root, &H);
                    convert_heap_idx_to_global(ctx, &H);
                }
            }
        }
    
        /* sendout results */
    
        /* 
         * dummy pointers to clarify counts in this part
         * act like an alias for rcv and snd counts
         */ 
    
        int* ngbh_to_send = point_to_rcv_count;
        int* ngbh_to_recv = point_to_snd_count;
    
        /*
         * counts are inverted since I have to recieve as many batches as points I
         * Have originally sended
         */
    
        elapsed_time = TIME_STOP;
        LOG_WRITE("Ngbh search for foreing points", elapsed_time);
    
        TIME_START;
        
        MPI_Datatype MPI_my_heap;
        MPI_Type_contiguous(k * sizeof(heap_node), MPI_CHAR, &MPI_my_heap);
        MPI_Barrier(ctx -> mpi_communicator);
        MPI_Type_commit(&MPI_my_heap);
    
        heap_node** rcv_heap_batches = (heap_node**)malloc(ctx -> world_size * sizeof(heap_node*));
        for(int i = 0; i < ctx -> world_size; ++i)
        {
            rcv_heap_batches[i] = __heap_batches_to_rcv + snd_displ[i] * k;
        }
    
        /* -------------------------------------
         * ALTERNATIVE TO ALL TO ALL FOR BIG MSG
         * HERE IT BREAKS mpi cannot handle msg
         * lager than 4GB
         * ------------------------------------- */
        
        MPI_Barrier(ctx -> mpi_communicator);
        int default_msg_len = MAX_MSG_SIZE / (k * sizeof(heap_node));
    
        int* already_sent_points = (int*)malloc(ctx -> world_size * sizeof(int));
        int* already_rcvd_points = (int*)malloc(ctx -> world_size * sizeof(int));
    
        /* allocate a request array to keep track of all requests going out*/
        MPI_Request* req_array;
        int req_num = 0;
        for(int i = 0; i < ctx -> world_size; ++i)
        {
            req_num += ngbh_to_send[i] > 0 ? ngbh_to_send[i]/default_msg_len + 1 : 0;
        }
    
        req_array = (MPI_Request*)malloc(req_num * sizeof(MPI_Request));
    
        for(int i = 0; i < ctx -> world_size; ++i)
        {
            already_sent_points[i] = 0;
            already_rcvd_points[i] = 0;
        }
    
        int req_idx = 0;
    
        for(int i = 0; i < ctx -> world_size; ++i)
        {
            int count = 0;
            if(ngbh_to_send[i] > 0)
            {
                while(already_sent_points[i] < ngbh_to_send[i])
                {
                    MPI_Request request;
                    count = MIN(default_msg_len, ngbh_to_send[i] - already_sent_points[i] );
                    MPI_Isend(  heap_batches_per_node[i] + k * already_sent_points[i], count,  
                            MPI_my_heap, i, 0, ctx -> mpi_communicator, &request);
                    already_sent_points[i] += count;
                    req_array[req_idx] = request;
                    ++req_idx;
                }
            }
        }
        
        MPI_Barrier(ctx -> mpi_communicator);
        MPI_Iprobe(MPI_ANY_SOURCE, MPI_ANY_TAG, ctx -> mpi_communicator, &flag, &status);
        //DB_PRINT("%d %p %p\n",ctx -> mpi_rank, &flag, &status);
        //HERE
        while(flag)
        {
            MPI_Request request;
            int count; 
            int source = status.MPI_SOURCE;
            MPI_Get_count(&status, MPI_my_heap, &count);
            /* recieve each slice */
    
            MPI_Recv(rcv_heap_batches[source] + k * already_rcvd_points[source], 
                    count, MPI_my_heap, source, MPI_ANY_TAG, ctx -> mpi_communicator, &status);
    
            already_rcvd_points[source] += count;
            MPI_Iprobe(MPI_ANY_SOURCE, MPI_ANY_TAG, ctx -> mpi_communicator, &flag, &status);
    
        }
        MPI_Barrier(ctx -> mpi_communicator);
    
    
        MPI_Testall(req_num, req_array, &flag, MPI_STATUSES_IGNORE);
    
        if(flag == 0)
        {
            DB_PRINT("[!!!] Rank %d has unfinished communications\n", ctx -> mpi_rank);
            exit(1);
        }
        free(req_array);
        free(already_sent_points);
        free(already_rcvd_points);
    
        elapsed_time = TIME_STOP;
        LOG_WRITE("Sending results to other proc", elapsed_time);
    
        /* merge old with new heaps */
    
        MPI_Barrier(ctx -> mpi_communicator);
    
        TIME_START;
    
        for(int i = 0; i < ctx -> world_size; ++i)
        {
            #pragma omp paralell for
            for(int b = 0; b < ngbh_to_recv[i]; ++b)
            {
                int idx = local_idx_of_the_point[i][b];
                /* retrieve the heap */
                heap H;
                H.count = k;
                H.N     = k;
                H.data  = rcv_heap_batches[i] + k*b;
                /* insert the points into the heap */
                for(int j = 0; j < k; ++j)
                {
                    insert_max_heap(&(dp_info[idx].ngbh), H.data[j].value, H.data[j].array_idx);
                }
            }
        }
        
        /* heapsort them */
    
        #pragma omp parallel for
        for(int i = 0; i < ctx -> local_n_points; ++i)
        {
            heap_sort(&(dp_info[i].ngbh));
        }
    
        elapsed_time = TIME_STOP;
        LOG_WRITE("Merging results", elapsed_time);
    
        #if defined(WRITE_NGBH)
        MPI_DB_PRINT("Writing ngbh to files\n");
            char ngbh_out[80];
            sprintf(ngbh_out, "./bb/rank_%d.ngbh",ctx -> mpi_rank);
            FILE* file = fopen(ngbh_out,"w");
            if(!file) 
            {
                printf("Cannot open file %s\n",ngbh_out);
            }
            else
            {
                for(int i = 0; i < ctx -> local_n_points; ++i)
                {
                    fwrite(dp_info[i].ngbh.data, sizeof(heap_node), k, file);
                }
                fclose(file);
            }
        #endif
    
        MPI_Barrier(ctx -> mpi_communicator);
    
        for(int i = 0; i < ctx -> world_size; ++i)
        {
            if(data_to_send_per_proc[i])  free(data_to_send_per_proc[i]);
            if(local_idx_of_the_point[i]) free(local_idx_of_the_point[i]);
        }
    
    
        free(data_to_send_per_proc);
        free(local_idx_of_the_point);
        free(heap_batches_per_node);
        free(rcv_heap_batches);
        free(rcv_work_batches);
        free(point_to_rcv_count);
        free(point_to_snd_count);
        free(point_to_snd_capacity);
    
        free(rcv_count);
        free(snd_count);
        free(rcv_displ);
        free(snd_displ);
        free(__heap_batches_to_rcv);
        free(__heap_batches_to_snd);
        free(__rcv_points);
        free(__snd_points);
    
    }
    
    /*
     * Diagnostic for the "Isend + barrier + Iprobe" receive pattern.
     *
     * Every rank sends one int (the destination's rank number) to every rank,
     * including itself, so each rank must receive exactly world_size messages.
     * After a barrier, the opportunistic MPI_Iprobe loop receives whatever is
     * already visible, and the number caught is printed.
     *
     * MPI does not promise that a message sent before a barrier is visible to
     * MPI_Iprobe after it. Any messages the loop missed are therefore received
     * with blocking calls afterwards, so no message is left pending and every
     * send request can be completed before the buffers are freed.
     */
    void test_the_idea(global_context_t* ctx)
    {
        const int tag = 0;
        int* num  = (int*)malloc(ctx -> world_size * sizeof(int));
        /* separate receive buffer: num is the send buffer and must stay untouched until the sends complete */
        int* rcvd = (int*)malloc(ctx -> world_size * sizeof(int));
        MPI_Request* send_reqs = (MPI_Request*)malloc(ctx -> world_size * sizeof(MPI_Request));

        for(int i = 0; i < ctx -> world_size; ++i) num[i] = i;

        for(int i = 0; i < ctx -> world_size; ++i)
        {
            MPI_Isend(num + i, 1, MPI_INT, i, tag, ctx -> mpi_communicator, &send_reqs[i]);
        }
        MPI_Barrier(ctx -> mpi_communicator);

        MPI_Status status;
        int flag;
        int cc = 0;
        /* the second argument is the TAG: MPI_ANY_SOURCE is not a valid tag value */
        MPI_Iprobe(MPI_ANY_SOURCE, tag, ctx -> mpi_communicator, &flag, &status);
        while(flag)
        {
            cc++;
            MPI_Recv(rcvd + status.MPI_SOURCE, 1, MPI_INT, status.MPI_SOURCE, tag, ctx -> mpi_communicator, &status);
            MPI_Iprobe(MPI_ANY_SOURCE, tag, ctx -> mpi_communicator, &flag, &status);
        }
        MPI_DB_PRINT("Recieved %d msgs\n",cc);

        /* collect the messages the Iprobe loop did not see yet */
        for(int left = ctx -> world_size - cc; left > 0; --left)
        {
            MPI_Probe(MPI_ANY_SOURCE, tag, ctx -> mpi_communicator, &status);
            MPI_Recv(rcvd + status.MPI_SOURCE, 1, MPI_INT, status.MPI_SOURCE, tag, ctx -> mpi_communicator, &status);
        }

        /* every Isend must be completed before its buffer is released */
        MPI_Waitall(ctx -> world_size, send_reqs, MPI_STATUSES_IGNORE);

        free(send_reqs);
        free(rcvd);
        free(num);
    }
    
    /*
     * Build the kd-tree over the node array already stored in local_tree
     * (_nodes / n_nodes) using ctx->dims dimensions, and store the root that
     * build_tree_kdtree_v2 returns back into local_tree->root.
     */
    void build_local_tree(global_context_t* ctx, kdtree_v2* local_tree)
    {
        local_tree -> root = build_tree_kdtree_v2(
            local_tree -> _nodes,
            local_tree -> n_nodes,
            ctx -> dims
        );
    }
    
    /*
     * Gather every rank's local_data (dims * local_n_points values each,
     * concatenated in rank order) on rank 0 and write it as raw float_t values
     * to bb/ordered_data.npy.
     * Collective: every rank of ctx->mpi_communicator must call this.
     *
     * NOTE(review): despite the .npy extension, no NumPy header is written; the
     * file is a raw binary dump.
     * NOTE(review): MPI counts/displacements are int, so each rank's slice and
     * the running displacement must stay below INT_MAX values.
     */
    void ordered_data_to_file(global_context_t* ctx)
    {
        //MPI_Barrier(ctx -> mpi_communicator);
        MPI_DB_PRINT("[MASTER] writing to file\n");
        const char* fname = "bb/ordered_data.npy";

        /* receive-side buffers exist only on rank 0; NULL elsewhere so no indeterminate pointer reaches MPI */
        float_t* tmp_data = NULL;
        int* ppp          = NULL;
        int* displs       = NULL;
        size_t tot_values = (size_t)ctx -> dims * (size_t)ctx -> n_points;

        MPI_Barrier(ctx -> mpi_communicator);
        if(I_AM_MASTER)
        {
            tmp_data = (float_t*)malloc(tot_values * sizeof(float_t));
            ppp      = (int*)malloc(ctx -> world_size * sizeof(int));
            displs   = (int*)malloc(ctx -> world_size * sizeof(int));
            if((tmp_data == NULL && tot_values > 0) || ppp == NULL || displs == NULL)
            {
                DB_PRINT("Rank %d failed to allocate gather buffers for %s\n", ctx -> mpi_rank, fname);
                MPI_Abort(ctx -> mpi_communicator, 1);
            }
        }

        MPI_Gather(&(ctx -> local_n_points), 1, MPI_INT, ppp, 1, MPI_INT, 0, ctx -> mpi_communicator);

        if(I_AM_MASTER)
        {
            /* convert point counts into value counts and build the displacements */
            displs[0] = 0;
            for(int i = 0; i < ctx -> world_size; ++i) ppp[i]    = ctx -> dims * ppp[i];
            for(int i = 1; i < ctx -> world_size; ++i) displs[i] = displs[i - 1] + ppp[i - 1];
        }
        MPI_Gatherv(ctx -> local_data, ctx -> dims * ctx -> local_n_points,
                MPI_MY_FLOAT, tmp_data, ppp, displs, MPI_MY_FLOAT, 0, ctx -> mpi_communicator);

        if(I_AM_MASTER)
        {
            /* binary mode: the payload is raw float_t values */
            FILE* file = fopen(fname, "wb");
            if(!file)
            {
                printf("Cannot open file %s\n", fname);
            }
            else
            {
                if(fwrite(tmp_data, sizeof(float_t), tot_values, file) != tot_values)
                {
                    printf("Short write on %s\n", fname);
                }
                fclose(file);
            }
        }

        /* free(NULL) is a no-op on non-root ranks */
        free(tmp_data);
        free(ppp);
        free(displs);
        MPI_Barrier(ctx -> mpi_communicator);
    }
    
    /*
     * Gather n elements of el_size bytes from every rank (concatenated in rank
     * order) onto rank 0 and write them as one raw binary file named fname.
     * Collective: every rank of ctx->mpi_communicator must call this.
     *
     * NOTE(review): MPI counts here are int bytes, so el_size * n on each rank
     * and the running byte displacement must stay below INT_MAX.
     */
    void ordered_buffer_to_file(global_context_t* ctx, void* buffer, size_t el_size, uint64_t n, const char* fname)
    {
        //MPI_Barrier(ctx -> mpi_communicator);
        MPI_DB_PRINT("[MASTER] writing to file\n");

        /* receive-side buffers exist only on rank 0; NULL elsewhere so no indeterminate pointer reaches MPI */
        void* tmp_data = NULL;
        int* ppp       = NULL;
        int* displs    = NULL;

        MPI_Barrier(ctx -> mpi_communicator);

        uint64_t tot_n = 0;
        MPI_Reduce(&n, &tot_n, 1, MPI_UINT64_T , MPI_SUM, 0, ctx -> mpi_communicator);

        if(I_AM_MASTER)
        {
            tmp_data = malloc(el_size * tot_n);
            ppp      = (int*)malloc(ctx -> world_size * sizeof(int));
            displs   = (int*)malloc(ctx -> world_size * sizeof(int));
            if((tmp_data == NULL && el_size * tot_n > 0) || ppp == NULL || displs == NULL)
            {
                DB_PRINT("Rank %d failed to allocate gather buffers for %s\n", ctx -> mpi_rank, fname);
                MPI_Abort(ctx -> mpi_communicator, 1);
            }
        }

        int nn = (int)n;
        MPI_Gather(&nn, 1, MPI_INT, ppp, 1, MPI_INT, 0, ctx -> mpi_communicator);

        if(I_AM_MASTER)
        {
            /* convert element counts into byte counts and build the displacements */
            displs[0] = 0;
            for(int i = 0; i < ctx -> world_size; ++i) ppp[i]    = (int)(el_size * (size_t)ppp[i]);
            for(int i = 1; i < ctx -> world_size; ++i) displs[i] = displs[i - 1] + ppp[i - 1];
        }

        MPI_Gatherv(buffer, (int)(el_size * n),
                MPI_CHAR, tmp_data, ppp, displs, MPI_CHAR, 0, ctx -> mpi_communicator);

        if(I_AM_MASTER)
        {
            /* binary mode: the payload is raw bytes */
            FILE* file = fopen(fname, "wb");
            if(!file)
            {
                printf("Cannot open file %s\n", fname);
            }
            else
            {
                if(fwrite(tmp_data, 1, el_size * tot_n, file) != el_size * tot_n)
                {
                    printf("Short write on %s\n", fname);
                }
                fclose(file);
            }
        }

        /* free(NULL) is a no-op on non-root ranks */
        free(tmp_data);
        free(ppp);
        free(displs);
        MPI_Barrier(ctx -> mpi_communicator);
    }
    
    /*
     * Return the rank owning global index idx.
     *
     * The local range [idx_start, idx_start + local_n_points) is checked first.
     * Otherwise the per-rank ranges [rank_idx_start[r], rank_idx_start[r] +
     * rank_n_points[r]) are scanned in rank order, and the first match wins.
     * If idx lies in none of them, the last rank (world_size - 1) is returned,
     * or this rank when world_size is 0. No error is signalled in that case.
     */
    static inline int foreign_owner(global_context_t* ctx, idx_t idx)
    {
        const int last_rank = ctx -> world_size - 1;

        if(idx >= ctx -> idx_start && idx < ctx -> idx_start + ctx -> local_n_points)
        {
            return ctx -> mpi_rank;
        }

        /* the last rank is the fallback answer, so its range need not be tested */
        for(int r = 0; r < last_rank; ++r)
        {
            if(idx >= ctx -> rank_idx_start[r] && idx < ctx -> rank_idx_start[r] + ctx -> rank_n_points[r])
            {
                return r;
            }
        }

        return last_rank >= 0 ? last_rank : ctx -> mpi_rank;
    }
    
    /*
     * Append element to lists[owner] and advance counts[owner].
     * The slot is reserved with an OpenMP atomic capture, so concurrent callers
     * inside a parallel region each get a distinct position. The caller must
     * have sized lists[owner] for every append.
     */
    static inline void append_foreign_idx_list(idx_t element, int owner, int* counts, idx_t** lists)
    {
        int slot;

        #pragma omp atomic capture
        slot = counts[owner]++;

        idx_t* dst = lists[owner];
        dst[slot] = element;
    }
    
    /* qsort comparator for idx_t values: ascending order, returns -1 / 0 / 1. */
    int cmp_idx(const void* a, const void* b)
    {
        const idx_t lhs = *(const idx_t*)a;
        const idx_t rhs = *(const idx_t*)b;

        if(lhs < rhs) return -1;
        if(lhs > rhs) return 1;
        return 0;
    }
    
    void find_foreign_nodes(global_context_t* ctx, datapoint_info_t* dp, datapoint_info_t** foreign_dp)
    {
        int k = dp[0].ngbh.count;
        
        idx_t** array_indexes_to_request = (idx_t**)malloc(ctx -> world_size * sizeof(idx_t*));
        int*    count_to_request         = (int*)malloc(ctx -> world_size * sizeof(int));
        int*    capacities               = (int*)malloc(ctx -> world_size * sizeof(int));
    
    
        for(int i = 0; i < ctx -> world_size; ++i)
        {
            capacities[i] = 0;
            count_to_request[i] = 0;
        }
    
        /* count them */
        
        #pragma omp parallel for
        for(uint32_t i = 0; i < ctx -> local_n_points; ++i)
        {
            for(int j = 0; j < k; ++j)
            {
                idx_t element = dp[i].ngbh.data[j].array_idx;        
                int owner = foreign_owner(ctx, element);
                //DB_PRINT("%lu %d\n", element, owner);
                
                if(owner != ctx -> mpi_rank)
                {
                    #pragma  omp atomic update
                    capacities[owner]++;
                }
            }
        }
    
        /* alloc */
    
        for(int i = 0; i < ctx -> world_size; ++i)
        {
            array_indexes_to_request[i] = (idx_t*)malloc(capacities[i] * sizeof(idx_t));
        }
    
        /* append them */
        
        #pragma omp parallel for
        for(uint32_t i = 0; i < ctx -> local_n_points; ++i)
        {
            for(int j = 0; j < k; ++j)
            {
                idx_t element = dp[i].ngbh.data[j].array_idx;        
                int owner = foreign_owner(ctx, element);
                //DB_PRINT("%lu %d\n", element, owner);
                
                if(owner != ctx -> mpi_rank)
                {
                    append_foreign_idx_list(element, owner, count_to_request, array_indexes_to_request); 
                }
            }
        }
    
        /* prune them */
        int* unique_count = (int*)malloc(ctx -> world_size * sizeof(int));
    
        /*
        if(I_AM_MASTER)
        {
            FILE* f = fopen("uniq","w");
            fwrite(array_indexes_to_request[1], sizeof(idx_t), capacities[1],f);
            fclose(f);
        }
        */
    
        #pragma omp paralell for
        for(int i = 0; i < ctx -> world_size; ++i)
        {
           unique_count[i] = capacities[i] > 0; //init unique count 
           qsort(array_indexes_to_request[i], capacities[i], sizeof(idx_t), cmp_idx); 
           uint32_t prev = array_indexes_to_request[i][0];
           for(int el = 1; el < capacities[i]; ++el)
           {
                int flag = prev == array_indexes_to_request[i][el];
                if(!flag)
                {
                    /* in place subsitution 
                     * if a different element is found then 
                     * copy in at the next free place
                     * */
                    array_indexes_to_request[i][unique_count[i]] = array_indexes_to_request[i][el];
                    unique_count[i]++;
                }
                prev = array_indexes_to_request[i][el];
           }
        }
    
        for(int i = 0; i < ctx -> world_size; ++i)
        {
            array_indexes_to_request[i] = (idx_t*)realloc(array_indexes_to_request[i], unique_count[i] * sizeof(idx_t));
        }
    
        for(int i = 0; i < ctx -> world_size; ++i)
        {
            foreign_dp[i] = (datapoint_info_t*)calloc(sizeof(datapoint_info_t), unique_count[i]);
        }
        
        /* alias for unique counts */
        int* n_heap_to_recv = unique_count;
        int* n_heap_to_send = (int*)malloc(ctx -> world_size * sizeof(int));
    
        /* exchange how many to recv how many to send */
    
        MPI_Alltoall(n_heap_to_recv, 1, MPI_INT, n_heap_to_send, 1, MPI_INT , ctx -> mpi_communicator);
    
        /* compute displacements and yada yada */
        int* sdispls = (int*)calloc(ctx -> world_size , sizeof(int));
        int* rdispls = (int*)calloc(ctx -> world_size , sizeof(int));
        
        int tot_count_send = n_heap_to_send[0];
        int tot_count_recv = n_heap_to_recv[0];
        for(int i = 1; i < ctx -> world_size; ++i)
        {
            sdispls[i] = sdispls[i - 1] + n_heap_to_send[i - 1];
            rdispls[i] = rdispls[i - 1] + n_heap_to_recv[i - 1];
    
            tot_count_send += n_heap_to_send[i];
            tot_count_recv += n_heap_to_recv[i];
        }
        idx_t* idx_buffer_to_send = (idx_t*)malloc(tot_count_send * sizeof(idx_t));
        idx_t* idx_buffer_to_recv = (idx_t*)malloc(tot_count_recv * sizeof(idx_t));
        
        /* copy indexes on send buffer */
        for(int i = 0; i < ctx -> world_size; ++i)
        {
            memcpy(idx_buffer_to_recv + rdispls[i], array_indexes_to_request[i], n_heap_to_recv[i] * sizeof(idx_t));
        }
        
        MPI_Alltoallv(idx_buffer_to_recv, n_heap_to_recv, rdispls, MPI_UINT64_T, idx_buffer_to_send, n_heap_to_send, sdispls, MPI_UINT64_T, ctx -> mpi_communicator);
    
    
        /* allocate foreign dp */ 
        heap_node* heap_buffer_to_send = (heap_node*)malloc(tot_count_send * k * sizeof(heap_node));
        heap_node* heap_buffer_to_recv = (heap_node*)malloc(tot_count_recv * k * sizeof(heap_node));
    
        for(int i = 0; i < tot_count_send; ++i)
        {
            idx_t idx = idx_buffer_to_send[i] - ctx -> idx_start;
            memcpy(heap_buffer_to_send + i * k, dp[idx].ngbh.data, k * sizeof(heap_node));
        }
        /* exchange heaps */
    
    
        MPI_Barrier(ctx -> mpi_communicator);
        int default_msg_len = MAX_MSG_SIZE / (k * sizeof(heap_node));
    
        int* already_sent_points = (int*)malloc(ctx -> world_size * sizeof(int));
        int* already_rcvd_points = (int*)malloc(ctx -> world_size * sizeof(int));
    
        /* allocate a request array to keep track of all requests going out*/
        MPI_Request* req_array;
        int req_num = 0;
        for(int i = 0; i < ctx -> world_size; ++i)
        {
            req_num += n_heap_to_send[i] > 0 ? n_heap_to_send[i]/default_msg_len + 1 : 0;
        }
    
        req_array = (MPI_Request*)malloc(req_num * sizeof(MPI_Request));
    
        for(int i = 0; i < ctx -> world_size; ++i)
        {
            already_sent_points[i] = 0;
            already_rcvd_points[i] = 0;
        }
    
        int req_idx = 0;
        
        MPI_Datatype MPI_my_heap;
        MPI_Type_contiguous(k * sizeof(heap_node), MPI_CHAR, &MPI_my_heap);
        MPI_Barrier(ctx -> mpi_communicator);
        MPI_Type_commit(&MPI_my_heap);
    
        for(int i = 0; i < ctx -> world_size; ++i)
        {
            int count = 0;
            if(n_heap_to_send[i] > 0)
            {
                while(already_sent_points[i] < n_heap_to_send[i])
                {
                    MPI_Request request;
                    count = MIN(default_msg_len, n_heap_to_send[i] - already_sent_points[i] );
                    MPI_Isend(  heap_buffer_to_send + k * (already_sent_points[i] + sdispls[i]), count,  
                            MPI_my_heap, i, 0, ctx -> mpi_communicator, &request);
                    already_sent_points[i] += count;
                    req_array[req_idx] = request;
                    ++req_idx;
                }
            }
        }
        int flag; 
        MPI_Status status;
        MPI_Barrier(ctx -> mpi_communicator);
        MPI_Iprobe(MPI_ANY_SOURCE, MPI_ANY_TAG, ctx -> mpi_communicator, &flag, &status);
        //DB_PRINT("%d %p %p\n",ctx -> mpi_rank, &flag, &status);
        while(flag)
        {
            MPI_Request request;
            int count; 
            int source = status.MPI_SOURCE;
            MPI_Get_count(&status, MPI_my_heap, &count);
            /* recieve each slice */
    
            MPI_Recv(heap_buffer_to_recv + k * (already_rcvd_points[source] + rdispls[source]), 
                    count, MPI_my_heap, source, MPI_ANY_TAG, ctx -> mpi_communicator, &status);
    
            already_rcvd_points[source] += count;
            MPI_Iprobe(MPI_ANY_SOURCE, MPI_ANY_TAG, ctx -> mpi_communicator, &flag, &status);
    
        }
        MPI_Barrier(ctx -> mpi_communicator);
    
        MPI_Type_free(&MPI_my_heap);
    
        MPI_Testall(req_num, req_array, &flag, MPI_STATUSES_IGNORE);
    
        if(flag == 0)
        {
            DB_PRINT("[!!!] Rank %d has unfinished communications\n", ctx -> mpi_rank);
            exit(1);
        }
        free(req_array);
        free(already_sent_points);
        free(already_rcvd_points);
    
    
        /* copy results where needed */
        for(int i = 0; i < ctx -> world_size; ++i)
        {
            for(int j = 0; j < n_heap_to_recv[i]; ++j)
            {
                /*
                foreign_dp[i][j].array_idx = array_indexes_to_request[i][j];
                init_heap(&(foreign_dp[i][j].ngbh));
                allocate_heap(&(foreign_dp[i][j].ngbh), k);
                foreign_dp[i][j].ngbh.N     = k;
                foreign_dp[i][j].ngbh.count = k;
                memcpy(foreign_dp[i][j].ngbh.data, heap_buffer_to_recv + k * (j + rdispls[i]), k * sizeof(heap_node));
                */
    
                foreign_dp[i][j].array_idx = array_indexes_to_request[i][j];
                //init_heap(&(foreign_dp[i][j].ngbh));
                foreign_dp[i][j].ngbh.N     = k;
                foreign_dp[i][j].ngbh.count = k;
                foreign_dp[i][j].ngbh.data =  heap_buffer_to_recv + k * (j + rdispls[i]);
    
                if(foreign_dp[i][j].ngbh.data[0].array_idx != array_indexes_to_request[i][j])
                {
                    printf("Error on %lu\n",array_indexes_to_request[i][j]);
                }
            }
        }
        
    
        /* put back indexes in the context */
    
        ctx -> idx_halo_points_recv = array_indexes_to_request; 
        ctx -> n_halo_points_recv   = n_heap_to_recv;
    
        ctx -> n_halo_points_send = n_heap_to_send;
        ctx -> idx_halo_points_send = (idx_t**)malloc(ctx -> world_size * sizeof(idx_t*));
        for(int i = 0; i < ctx -> world_size; ++i) ctx -> idx_halo_points_send[i] = NULL;
    
        for(int i = 0; i < ctx -> world_size; ++i)
        {
            ctx -> idx_halo_points_send[i] = (idx_t*)malloc( n_heap_to_send[i] * sizeof(idx_t));
            memcpy(ctx -> idx_halo_points_send[i], idx_buffer_to_send + sdispls[i], n_heap_to_send[i] * sizeof(idx_t) ); 
        }
    
        ctx -> halo_datapoints = foreign_dp;
        ctx -> local_datapoints = dp;
    
        free(count_to_request);
        free(capacities);
    
        /* free(heap_buffer_to_recv); this needs to be preserved*/
        ctx -> __recv_heap_buffers = heap_buffer_to_recv;
        free(heap_buffer_to_send);
        free(idx_buffer_to_send);
        free(idx_buffer_to_recv);
    
        free(sdispls);
        free(rdispls);
    }
    
    /*
     * Least-squares slope m of a straight line through the origin, y = m * x:
     *
     *     m = sum_i(x_i * y_i) / sum_i(x_i * x_i)
     *
     * params:
     * - x: abscissae of the points
     * - y: ordinates of the points
     * - n: number of points used (the first n of each array)
     *
     * No guard for n == 0 or all-zero x: the result is then 0/0 (NaN).
     */
    float_t mEst2(float_t * x, float_t *y, idx_t n)
    {
        float_t sum_xy = 0;
        float_t sum_xx = 0;

        for(idx_t p = 0; p < n; ++p)
        {
            const float_t xp = x[p];
            sum_xy += xp * y[p];
            sum_xx += xp * xp;
        }

        return sum_xy / sum_xx;
    }
    
    /*
     * Maximum-likelihood TWO-NN estimate of the intrinsic dimension:
     *
     *     intrinsic_dim = (N - 1) / sum_i log(mu_i)
     *
     * For each of the n local points, log(mu_i) is 0.5 * log(data[2].value /
     * data[1].value) of its neighbour heap. The 0.5 suggests the stored values
     * are squared distances — TODO confirm.
     *
     * The local sums are added across ranks with MPI_Allreduce, and N is
     * ctx->n_points. Collective over ctx->mpi_communicator; every rank gets the
     * same value.
     *
     * args:
     * - dp_info: array of local datapoints
     * - n: number of entries of dp_info
     * - verbose: when non-zero, print the estimate and the elapsed time
     */
    float_t compute_ID_two_NN_ML(global_context_t* ctx, datapoint_info_t* dp_info, idx_t n, int verbose)
    {
        struct timespec t_begin, t_end;

        if(verbose)
        {
            printf("ID estimation:\n");
            clock_gettime(CLOCK_MONOTONIC, &t_begin);
        }

        float_t local_log_sum = 0;
        for(idx_t p = 0; p < n; ++p)
        {
            local_log_sum += 0.5 * log(dp_info[p].ngbh.data[2].value / dp_info[p].ngbh.data[1].value);
        }

        float_t global_log_sum = 0;
        MPI_Allreduce(&local_log_sum, &global_log_sum, 1, MPI_MY_FLOAT, MPI_SUM, ctx -> mpi_communicator);

        float_t id = global_log_sum;
        id = (ctx -> n_points - 1) / id;

        if(verbose)
        {
            clock_gettime(CLOCK_MONOTONIC, &t_end);
            double seconds = (t_end.tv_sec - t_begin.tv_sec);
            seconds += (t_end.tv_nsec - t_begin.tv_nsec) / 1000000000.0;
            printf("\tID value: %.6lf\n", id);
            printf("\tTotal time: %.3lfs\n\n", seconds);
        }

        return id;
    }
    
    
    /*
     * TWO-NN estimate of the intrinsic dimension by linear fit.
     *
     * For each of the n local points, r_i = 0.5 * log(data[2].value /
     * data[1].value). The r_i are sorted, and the first n * fraction of them are
     * fitted against -log(1 - (i + 1) / n) with a straight line through the
     * origin (mEst2).
     *
     * Each rank fits its own points. The returned value is the mean of the
     * per-rank slopes (MPI_Allreduce sum divided by world_size), so the call is
     * collective over ctx->mpi_communicator.
     *
     * args:
     * - dp_info: array of local datapoints
     * - n: number of entries of dp_info
     * - fraction: share of the sorted ratios used in the fit
     * - verbose: when non-zero, print this rank's local estimate and the elapsed time
     */
    float_t id_estimate(global_context_t* ctx, datapoint_info_t* dp_info, idx_t n, float_t fraction, int verbose)
    {
        struct timespec t_begin, t_end;

        if(verbose)
        {
            printf("ID estimation:\n");
            clock_gettime(CLOCK_MONOTONIC, &t_begin);
        }

        float_t* log_ratio        = (float_t*)malloc(n*sizeof(float_t));
        float_t* neg_log_survival = (float_t*)malloc(n*sizeof(float_t));

        for(idx_t p = 0; p < n; ++p)
        {
            log_ratio[p]        = 0.5 * log(dp_info[p].ngbh.data[2].value / dp_info[p].ngbh.data[1].value);
            neg_log_survival[p] = -log(1 - (float_t)(p + 1) / (float_t)n);
        }
        qsort(log_ratio, n, sizeof(float_t), cmp_float_t);

        const idx_t n_fit = (idx_t)(n*fraction);
        const float_t local_id = mEst2(log_ratio, neg_log_survival, n_fit);

        free(log_ratio);
        free(neg_log_survival);

        if(verbose)
        {
            clock_gettime(CLOCK_MONOTONIC, &t_end);
            double seconds = (t_end.tv_sec - t_begin.tv_sec);
            seconds += (t_end.tv_nsec - t_begin.tv_nsec) / 1000000000.0;
            printf("\tID value: %.6lf\n", local_id);
            printf("\tTotal time: %.3lfs\n\n", seconds);
        }

        float_t id_sum = 0;
        MPI_Allreduce(&local_id, &id_sum, 1, MPI_MY_FLOAT, MPI_SUM, ctx -> mpi_communicator);

        return id_sum / ctx -> world_size;
    }
    
    /*
     * Binary search for key in the ascending array idxs[0..n-1].
     *
     * Returns the position of key when it is present. When it is absent, the
     * last probed position is returned; it is not flagged, so callers must
     * compare idxs[result] with key themselves.
     *
     * idxs[0] is read even when n == 0, so callers must pass n >= 1.
     */
    int binary_search_on_idxs(idx_t* idxs, idx_t key, int n)
    {
        int lo  = 0;
        int hi  = n - 1;
        int mid = (hi - lo) / 2;

        while(idxs[mid] != key && lo < hi)
        {
            if(key < idxs[mid])
            {
                hi = mid - 1;   /* continue in the left half */
            }
            else
            {
                lo = mid + 1;   /* continue in the right half */
            }
            mid = lo + (hi - lo) / 2;
        }

        return mid;
    }
    
    /*
     * Return (by value) the datapoint with global index idx.
     *
     * - Local points are taken from ctx->local_datapoints at offset
     *   idx - idx_start.
     * - Halo points are located by binary search in the sorted index list
     *   received from their owner (ctx->idx_halo_points_recv[owner]) and taken
     *   from ctx->halo_datapoints[owner].
     *
     * A diagnostic is printed when the found halo entry's first neighbour is not
     * idx itself. The entry is returned regardless.
     */
    datapoint_info_t find_possibly_halo_datapoint(global_context_t* ctx, idx_t idx)
    {
        const int owner = foreign_owner(ctx, idx);

        if(owner == ctx -> mpi_rank)
        {
            return ctx -> local_datapoints[idx - ctx -> idx_start];
        }

        const int pos = binary_search_on_idxs(ctx -> idx_halo_points_recv[owner], idx, ctx -> n_halo_points_recv[owner]);
        const datapoint_info_t* halo = ctx -> halo_datapoints[owner];

        if(idx != halo[pos].ngbh.data[0].array_idx)
        {
            printf("Osti %lu\n", idx);
        }

        return halo[pos];
    }
    
    /*
     * kstarNN density estimation for every local datapoint.
     *
     * For each local point i the neighbourhood is grown starting at j = 4 while
     * j < kMAX and the statistic
     *
     *     dL = -2 * ksel * log(4 * Vi * Vj / (Vi + Vj)^2),   ksel = j - 1
     *
     * stays below DTHR. Vi (resp. Vj) is omega * value^(d/2), where value is
     * ngbh.data[ksel].value of point i (resp. of its j-th neighbour) and omega
     * is the volume of the unit d-ball, pi^(d/2) / Gamma(d/2 + 1). The j-th
     * neighbour may be owned by another rank; it is then fetched from the
     * received halo points via find_possibly_halo_datapoint.
     *
     * Fields written for each local point:
     *   kstar       : selected neighbourhood size k
     *   log_rho     : log(k / (Vk * ctx->n_points))
     *   log_rho_err : 1 / sqrt(k)
     *   g           : log_rho - log_rho_err
     *
     * args:
     * - ctx     : global context; ctx->local_datapoints must already hold the
     *             kNN lists (ngbh.data[].value is raised to d/2, so presumably
     *             squared distances — TODO confirm) and the halo points must
     *             already have been received
     * - d       : intrinsic dimension of the dataset
     * - verbose : if non-zero, print a header and the total elapsed time
     *
     * When WRITE_DENSITY is defined, log_rho is also written to
     * "bb/ordered_density.npy" via ordered_buffer_to_file, followed by a call
     * to ordered_data_to_file.
     */
    void compute_density_kstarnn(global_context_t* ctx, const float_t d, int verbose){

        struct timespec start_tot, finish_tot;
        double elapsed_tot;

        datapoint_info_t* local_datapoints = ctx -> local_datapoints;

        if(verbose)
        {
            printf("Density and k* estimation:\n");
            clock_gettime(CLOCK_MONOTONIC, &start_tot);
        }

        /*
         * kMAX is taken from the first point, i.e. all neighbour lists are
         * assumed to have the same length. A rank owning no points must not
         * read local_datapoints[0]; the loop below does not run for it anyway.
         * No early return here: the WRITE_DENSITY output below must still be
         * reached on every rank.
         */
        idx_t kMAX = ctx -> local_n_points > 0 ? local_datapoints[0].ngbh.N - 1 : 0;

        /* volume of the unit ball in d dimensions: pi^(d/2) / Gamma(d/2 + 1) */
        float_t omega = 0.;
        if(sizeof(float_t) == sizeof(float)){ omega = powf(PI_F,d/2)/tgammaf(d/2.0f + 1.0f);}
        else{omega = pow(M_PI,d/2.)/tgamma(d/2.0 + 1.0);}

        #pragma omp parallel for
        for(idx_t i = 0; i < ctx -> local_n_points; ++i)
        {
            idx_t j = 4;
            idx_t k;
            float_t dL  = 0.;
            float_t vvi = 0.;
            float_t vvj = 0.;
            float_t vp  = 0.;
            while(j < kMAX  && dL < DTHR)
            {
                idx_t ksel = j - 1;
                vvi = omega * pow(local_datapoints[i].ngbh.data[ksel].value,d/2.);

                /* the j-th neighbour can be a halo point owned by another rank */
                idx_t jj = local_datapoints[i].ngbh.data[j].array_idx;
                datapoint_info_t tmp_dp = find_possibly_halo_datapoint(ctx, jj);

                vvj = omega * pow(tmp_dp.ngbh.data[ksel].value,d/2.);

                vp = (vvi + vvj)*(vvi + vvj);
                dL = -2.0 * ksel * log(4.*vvi*vvj/vp);
                j = j + 1;
            }
            if(j == kMAX)
            {
                /* neighbour list exhausted: use the largest available k */
                k = j - 1;
                vvi = omega * pow(local_datapoints[i].ngbh.data[k].value,d/2.);
            }
            else
            {
                /* k is the ksel of the last iteration; vvi already holds its volume */
                k = j - 2;
            }
            local_datapoints[i].kstar = k;
            local_datapoints[i].log_rho = log((float_t)(k)/vvi/((float_t)(ctx -> n_points)));
            local_datapoints[i].log_rho_err =   1.0/sqrt((float_t)k);
            local_datapoints[i].g = local_datapoints[i].log_rho - local_datapoints[i].log_rho_err;
        }

        if(verbose)
        {
            clock_gettime(CLOCK_MONOTONIC, &finish_tot);
            elapsed_tot = (finish_tot.tv_sec - start_tot.tv_sec);
            elapsed_tot += (finish_tot.tv_nsec - start_tot.tv_nsec) / 1000000000.0;
            printf("\tTotal time: %.3lfs\n\n", elapsed_tot);
        }

        #if defined(WRITE_DENSITY)
            /* densities, in the same index type as the main loop */
            float_t* den = (float_t*)malloc(ctx -> local_n_points * sizeof(float_t));
            for(idx_t i = 0; i < ctx -> local_n_points; ++i) den[i] = local_datapoints[i].log_rho;

            ordered_buffer_to_file(ctx, den, sizeof(float_t), ctx -> local_n_points, "bb/ordered_density.npy");
            ordered_data_to_file(ctx);
            free(den);
        #endif
    }
    
    /*
     * End-to-end driver on a dataset read by the master rank:
     *   1. rank 0 reads the input file (path hard-coded below), sets ctx->dims
     *      and converts ctx->n_points to a number of points;
     *   2. dims and n_points are broadcast and the raw data is scattered in
     *      contiguous slices: every rank gets n_points / world_size points and
     *      the first n_points % world_size ranks get one extra;
     *   3. the top kd-tree is built and exchange_points redistributes the
     *      points according to it;
     *   4. local kd-tree build, distributed kNN search (k = 300),
     *      find_foreign_nodes, ID estimate (compute_ID_two_NN_ML) and kstarNN
     *      density estimate.
     *
     * args:
     * - dims : not used (dimensionality is set on rank 0)
     * - n    : not used (point count comes from the input file)
     * - ctx  : global context, filled in place (dims, n_points,
     *          local_n_points, local_data, ...)
     *
     * Aborts the MPI job if n_points * dims does not fit in the int counts
     * required by MPI_Scatterv.
     */
    void simulate_master_read_and_scatter(int dims, size_t n, global_context_t *ctx) 
    {
        /*
         * Only rank 0 reads the dataset. MPI_Scatterv ignores the send buffer
         * on the other ranks, but the pointer value is still passed there, so
         * give it a defined value (NULL is also safe to free()).
         */
        float_t *data = NULL;
        TIME_DEF
        double elapsed_time;

        if (ctx->mpi_rank == 0) 
        {
            /*
             * Alternative inputs used during development (read with MY_TRUE):
             *   ../norm_data/50_blobs_more_var.npy, ../norm_data/50_blobs.npy  (dims = 2)
             *   std_g0163178_Me14_091_0000
             *   ../norm_data/std_g2980844_091_0000        ~190M points
             *   ../norm_data/std_g5503149_091_0000        ~88M points
             *   ../norm_data/std_g1212639_091_0001        ~34M points
             *   ../norm_data/std_g0144846_Me14_091_0001   ~8M points
             * or synthetic data via generate_random_matrix(&data, ctx->dims, ctx->n_points, ctx).
             */

            /* ~1M points */
            data = read_data_file(ctx,"../norm_data/std_LR_091_0001",MY_TRUE);
            ctx->dims = 5;

            /* n_points is divided by dims here, i.e. read_data_file presumably
             * leaves the number of scalars read in it — TODO confirm */
            ctx->n_points = ctx->n_points / ctx->dims;
        }

        /* communicate the global problem size */
        MPI_Bcast(&(ctx->dims), 1, MPI_UINT32_T, 0, ctx->mpi_communicator);
        MPI_Bcast(&(ctx->n_points), 1, MPI_UINT64_T, 0, ctx->mpi_communicator);

        /*
         * MPI_Scatterv takes int counts and displacements (in scalars); the
         * last displacement plus its count equals n_points * dims, so that
         * total must fit in an int (32 bits on the supported MPI platforms)
         * or the arithmetic below overflows. All ranks hold the same broadcast
         * values, so they all take this branch together.
         */
        uint64_t total_scalars = (uint64_t)ctx->n_points * (uint64_t)ctx->dims;
        if (total_scalars > (uint64_t)INT32_MAX)
        {
            if (ctx->mpi_rank == 0)
            {
                fprintf(stderr,
                        "simulate_master_read_and_scatter: %llu values exceed the int range of MPI_Scatterv counts\n",
                        (unsigned long long)total_scalars);
            }
            MPI_Abort(ctx->mpi_communicator, EXIT_FAILURE);
        }

        /* number of scalars each rank receives and where its slice starts */
        int *send_counts = (int *)malloc(ctx->world_size * sizeof(int));
        int *displacements = (int *)malloc(ctx->world_size * sizeof(int));

        uint64_t base_points  = ctx->n_points / ctx->world_size;
        uint64_t extra_points = ctx->n_points % ctx->world_size;
        for (int p = 0; p < ctx->world_size; ++p) 
        {
            uint64_t p_points = base_points + ((uint64_t)p < extra_points ? 1 : 0);
            send_counts[p] = (int)(p_points * ctx->dims);
            displacements[p] = (p == 0) ? 0 : displacements[p - 1] + send_counts[p - 1];
        }

        ctx->local_n_points = send_counts[ctx->mpi_rank] / ctx->dims;

        float_t *pvt_data = (float_t *)malloc(send_counts[ctx->mpi_rank] * sizeof(float_t));

        MPI_Scatterv(data, send_counts, displacements, MPI_MY_FLOAT, pvt_data, send_counts[ctx->mpi_rank], MPI_MY_FLOAT, 0, ctx->mpi_communicator);

        ctx->local_data = pvt_data;

        int k_global = 20;

        pointset_t original_ps;
        original_ps.data = ctx->local_data;
        original_ps.dims = ctx->dims;
        original_ps.n_points = ctx->local_n_points;
        original_ps.lb_box = (float_t*)malloc(ctx -> dims * sizeof(float_t));
        original_ps.ub_box = (float_t*)malloc(ctx -> dims * sizeof(float_t));

        float_t tol = 0.002;

        top_kdtree_t tree;
        TIME_START;
        top_tree_init(ctx, &tree);
        elapsed_time = TIME_STOP;
        LOG_WRITE("Initializing global kdtree", elapsed_time);

        TIME_START;
        build_top_kdtree(ctx, &original_ps, &tree, k_global, tol);
        exchange_points(ctx, &tree);
        elapsed_time = TIME_STOP;
        LOG_WRITE("Top kdtree build and domain decomposition", elapsed_time);

        TIME_START;
        kdtree_v2 local_tree;
        kdtree_v2_init( &local_tree, ctx -> local_data, ctx -> local_n_points, (unsigned int)ctx -> dims);
        int k = 300;

        datapoint_info_t* dp_info = (datapoint_info_t*)malloc(ctx -> local_n_points * sizeof(datapoint_info_t));            
        /* initialize every field, to cope with valgrind */
        for(uint64_t i = 0; i < ctx -> local_n_points; ++i)
        {
            dp_info[i].ngbh.data = NULL;
            dp_info[i].ngbh.N = 0;
            dp_info[i].ngbh.count = 0;
            dp_info[i].g = 0.f;
            dp_info[i].log_rho = 0.f;
            dp_info[i].log_rho_c = 0.f;
            dp_info[i].log_rho_err = 0.f;
            dp_info[i].array_idx = -1;
            dp_info[i].kstar = -1;
            dp_info[i].is_center = -1;
            dp_info[i].cluster_idx = -1;
        }

        build_local_tree(ctx, &local_tree);
        elapsed_time = TIME_STOP;
        LOG_WRITE("Local trees init and build", elapsed_time);

        TIME_START;
        MPI_DB_PRINT("----- Performing ngbh search -----\n");
        MPI_Barrier(ctx -> mpi_communicator);

        mpi_ngbh_search(ctx, dp_info, &tree, &local_tree, ctx -> local_data, k);

        MPI_Barrier(ctx -> mpi_communicator);
        elapsed_time = TIME_STOP;
        LOG_WRITE("Total time for all knn search", elapsed_time)

        TIME_START;

        datapoint_info_t** foreign_dp_info = (datapoint_info_t**)malloc(ctx -> world_size * sizeof(datapoint_info_t*));
        find_foreign_nodes(ctx, dp_info, foreign_dp_info);
        elapsed_time = TIME_STOP;
        LOG_WRITE("Finding points to request the ngbh", elapsed_time)

        TIME_START;
        float_t id = compute_ID_two_NN_ML(ctx, dp_info, ctx -> local_n_points, MY_FALSE);
        elapsed_time = TIME_STOP;
        LOG_WRITE("ID estimate", elapsed_time)

        MPI_DB_PRINT("ID %lf \n",id);

        TIME_START;
        compute_density_kstarnn(ctx, id, MY_FALSE);
        elapsed_time = TIME_STOP;
        LOG_WRITE("Density estimate", elapsed_time)

        #if defined (WRITE_NGBH)
            ordered_data_to_file(ctx);
        #endif

        /*
         * dp_info, its neighbour lists and foreign_dp_info are intentionally
         * not released here. NOTE(review): presumably they are kept alive
         * through ctx (compute_density_kstarnn reads ctx->local_datapoints and
         * ctx->halo_datapoints) — confirm ownership before freeing them.
         */
        top_tree_free(ctx, &tree);
        kdtree_v2_free(&local_tree);

        free(send_counts);
        free(displacements);
        free(data);

        /* original_ps.data aliased ctx->local_data: detach it so that
         * free_pointset presumably only releases the bounding boxes */
        original_ps.data = NULL;
        free_pointset(&original_ps);
    }