#!/usr/bin/env python
# coding: utf-8

import matplotlib.pyplot as plt
import numpy as np
from sklearn.neighbors import NearestNeighbors

# Number of coordinates per data point: used to reshape the raw data file and
# to split the bounding-box columns of the top-nodes CSV.
ndims = 5
# Number of neighbours per point: passed to NearestNeighbors and used to
# reshape the flat neighbour results into one row per point.
k = 500
# Number of "rank_<i>.ngbh" result files that are read and concatenated.
p = 10

# Raw text lines of the top-nodes CSV; parsed later by parse_lines().
with open("bb/top_nodes.csv", "r") as nodes_file:
    l = list(nodes_file)

def parse_lines(l,n_dims):
    """Parse the text lines of the top-nodes CSV into per-column arrays.

    Every non-blank line is split on "," and its fields are read positionally:

    * field 0 -> ``level`` (int)
    * field 1 -> ``owner`` (int)
    * field 2 -> ``split_dim`` (int)
    * field 3 -> ``split_val`` (float)
    * the next ``n_dims`` fields -> one row of ``box_lb`` (float)
    * all remaining fields -> one row of ``box_ub`` (float)

    Blank or whitespace-only lines (e.g. a trailing empty line at the end of
    the file) are skipped instead of raising ``ValueError``.

    Parameters
    ----------
    l : iterable of str
        Lines of the CSV file, with or without trailing newlines.
    n_dims : int
        Number of lower-bound columns that follow the four leading fields.

    Returns
    -------
    tuple of numpy.ndarray
        ``(level, owner, split_dim, split_val, box_lb, box_ub)``, each with one
        entry (or one row) per non-blank input line.

    Raises
    ------
    ValueError
        If a field cannot be converted to the expected numeric type.
    IndexError
        If a non-blank line has fewer than four fields.
    """
    # int()/float() tolerate surrounding whitespace, so the newline left on
    # the last field needs no special handling; only empty lines must go.
    rows = [line.split(",") for line in l if line.strip()]

    level = np.array([int(row[0]) for row in rows])
    owner = np.array([int(row[1]) for row in rows])
    split_dim = np.array([int(row[2]) for row in rows])
    split_val = np.array([float(row[3]) for row in rows])

    lb_end = 4 + n_dims
    box_lb = np.array([[float(el) for el in row[4:lb_end]] for row in rows])
    box_ub = np.array([[float(el) for el in row[lb_end:]] for row in rows])
    return level, owner, split_dim, split_val, box_lb, box_ub

def plot_boxes(x,d0,d1,owner, split_dim, split_val, box_lb, box_ub, ratio = 0.7):
    """Scatter the data in the (d0, d1) plane and overlay the owned boxes.

    One semi-transparent rectangle is added for every node whose ``owner`` is
    not -1. The rectangle spans ``box_lb``..``box_ub`` along columns ``d0``
    (horizontal) and ``d1`` (vertical), gets a random colour, and is labelled
    with the node's ``owner`` value in a legend placed at the lower left.

    ``split_dim`` and ``split_val`` are accepted but not used here. ``ratio``
    scales the base 12 x 10 figure size. The figure is created and populated
    but this function does not call ``plt.show()``.
    """
    from matplotlib.patches import Rectangle

    fig, ax = plt.subplots(figsize=(12 * ratio, 10 * ratio))
    ax.scatter(x[:, d0], x[:, d1], s=0.1)

    owned_nodes = np.nonzero(owner != -1)[0]
    for node in owned_nodes:
        corner = (box_lb[node, d0], box_lb[node, d1])
        width = box_ub[node, d0] - box_lb[node, d0]
        height = box_ub[node, d1] - box_lb[node, d1]
        # Random RGB with a fixed alpha of 0.5 so overlapping boxes stay visible.
        rgba = (np.random.rand(), np.random.rand(), np.random.rand(), 0.5)
        patch = Rectangle(corner, width, height, facecolor=rgba, label=owner[node])
        ax.add_patch(patch)

    plt.legend(loc="lower left")

def plot_planes(x,d0,d1,owner, split_dim, split_val, box_lb, box_ub, ratio=0.7):
    """Scatter the data in the (d0, d1) plane and draw the splitting lines.

    For every node whose ``owner`` equals -1:

    * if the node splits along ``d0``, a vertical yellow line is drawn at
      ``split_val`` spanning the node's box extent along ``d1``;
    * if the node splits along ``d1``, a horizontal yellow line is drawn at
      ``split_val`` spanning the node's box extent along ``d0``;
    * nodes splitting along any other dimension are not drawn.

    ``ratio`` scales the base 12 x 10 figure size. The figure is displayed
    with ``plt.show()`` before returning.
    """
    _, ax = plt.subplots(figsize=(12 * ratio, 10 * ratio))
    ax.scatter(x[:, d0], x[:, d1], s=0.1)

    split_nodes = np.nonzero(owner == -1)[0]
    for node in split_nodes:
        if split_dim[node] == d0:
            # Split along the horizontal axis -> vertical segment.
            ax.vlines(split_val[node], box_lb[node, d1], box_ub[node, d1], color="y")
        elif split_dim[node] == d1:
            # Split along the vertical axis -> horizontal segment.
            ax.hlines(split_val[node], box_lb[node, d0], box_ub[node, d0], color="y")
    plt.show()


def count_neighbor_errors(x, idx, dist, idx_c):
    """Count rows whose candidate neighbour list disagrees with the reference.

    Rows are compared position by position. Where the candidate index differs
    from the reference index, the Euclidean distance from the query point to
    the candidate neighbour is recomputed from ``x`` and compared (with
    ``numpy.allclose`` default tolerances) against the reference distance at
    the same rank. A differing index at an equal distance is a tie and is
    accepted; anything else marks the row as an error.

    Parameters
    ----------
    x : numpy.ndarray, shape (n_points, n_dims)
        Data points; row ``i`` is the query point of row ``i`` of the results.
    idx : numpy.ndarray, shape (n_rows, k)
        Reference neighbour indices into ``x``.
    dist : numpy.ndarray, shape (n_rows, k)
        Reference Euclidean distances matching ``idx``.
    idx_c : numpy.ndarray, shape (n_rows_c, k)
        Candidate neighbour indices into ``x`` (any integer dtype).

    Returns
    -------
    int
        Number of rows (among the rows present in both results) that contain
        at least one invalid candidate neighbour. A candidate index outside
        ``[0, n_points)`` is counted as invalid.
    """
    n_points = len(x)
    errors = 0
    for i, (ref_idx, ref_dist, cand_idx) in enumerate(zip(idx, dist, idx_c)):
        mismatch = np.flatnonzero(ref_idx != cand_idx)
        if mismatch.size == 0:
            continue
        # Cast so unsigned on-disk indices can be range-checked and used to index x.
        cand = cand_idx[mismatch].astype(np.intp)
        if np.any((cand < 0) | (cand >= n_points)):
            errors += 1
            continue
        true_dist = np.linalg.norm(x[cand] - x[i], axis=1)
        if not np.allclose(true_dist, ref_dist[mismatch]):
            errors += 1
    return errors


if __name__ == "__main__":
    level, owner, split_dim, split_val, box_lb, box_ub = parse_lines(l, ndims)

    print("Loading data file")
    # The file is read as a flat float64 buffer and viewed as rows of ndims values.
    x = np.fromfile("./bb/ordered_data.npy", np.float64)
    x = x.reshape((-1, ndims))

    # Optional visual checks of the parsed boxes / splitting planes:
    #plot_boxes(x,0,1,owner,split_dim,split_val,box_lb,box_ub)
    #plot_planes(x,0,1,owner,split_dim,split_val,box_lb,box_ub)

    print("Loading ngbh results")
    # Each record is a (float64 "value", uint64 "array_idx") pair; the files
    # are concatenated in rank order.
    ngbh_dtype = [("value","f8"),("array_idx","u8")]
    ngbh = np.concatenate(
        [np.fromfile(f"./bb/rank_{pp}.ngbh", dtype = ngbh_dtype) for pp in range(p)]
    )

    print("Searching for neighbors")
    nn = NearestNeighbors(n_jobs=-1,n_neighbors=k)
    nn.fit(x)
    dist, idx = nn.kneighbors(x)

    # One row of k neighbour indices per point. The "value" field is not used
    # for the check: candidate neighbours are validated against distances
    # recomputed from x, which does not depend on how "value" was stored.
    idx_c = ngbh["array_idx"].reshape((-1, k))

    if len(idx_c) != len(idx):
        print(f"Warning: {len(idx_c)} result rows for {len(idx)} points; "
              "only the common leading rows are checked")

    print("Check")
    abs_errors = count_neighbor_errors(x, idx, dist, idx_c)
    print(f"Found {abs_errors} errors")

