#!/usr/bin/env python
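"""
Convert a Socet Set project (an .atf file and its associated .gpf and .ipf
files) into an ISIS control network.

Illustrative invocation (the script name, paths, and target shown here are
hypothetical):

    python socet2isis.py /path/to/project.atf /path/to/cubes.lis Mars --outpath /path/to/out
"""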
import os
import sys
import argparse
import warnings
import pvl
import math
import pyproj

import numpy as np
import pandas as pd

from plio.io.io_bae import read_atf, read_gpf, read_ipf
import plio.io.io_controlnetwork as cn
import plio.io.isis_serial_number as sn
from plio.utils.utils import find_in_dict, split_all_ext

def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('at_file', help='Path to the .atf file for a project.')
    parser.add_argument('cub_list', help='Path to a list file containing paths to the associated ISIS cubes.')
    parser.add_argument('target_name', help='Name of the target body used in the control net')
    parser.add_argument('--outpath', help='Directory for the control network to be output to.')

    return parser.parse_args()

def line_sample_size(record, path):
    """
    Converts the l. and s. columns to sample and line coordinates and
    generates an image index from the associated .sup file

    Parameters
    ----------
    record : object
             Pandas series object

    path : str
           Path to the associated .sup files for a socet project

    Returns
    -------
    : tuple
      A tuple of sample_size, line_size, and img_index
    """
    with open(os.path.join(path, record['ipf_file'] + '.sup')) as f:
        for i, line in enumerate(f):
            if i == 2:
                img_index = line.split('\\')
                img_index = img_index[-1].strip()
                img_index = img_index.split('.')[0]

            if i == 3:
                line_size = line.split(' ')
                line_size = line_size[-1].strip()
                assert int(line_size) > 0, "Line size {} from {} is not a positive number: Invalid Data".format(line_size, record['ipf_file'])

            if i == 4:
                sample_size = line.split(' ')
                sample_size = sample_size[-1].strip()
                assert int(sample_size) > 0, "Sample size {} from {} is not a positive number: Invalid Data".format(sample_size, record['ipf_file'])
                break


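        # The .sup header gives the full image dimensions; the socet l./s.
        # values appear to be offsets from the image center, so shift by half
        # the image size plus one to get 1-based line/sample coordinates.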
        line_size = int(line_size)/2.0 + record['l.'] + 1
        sample_size = int(sample_size)/2.0 + record['s.'] + 1
        return sample_size, line_size, img_index

def get_axis(file):
    """
    Gets eRadius and pRadius from a .prj file

    Parameters
    ----------
    file : str
           file with path to a given socet project file

    Returns
    -------
    : list
      A list of the eRadius and pRadius of the project file
    """
    with open(file) as f:
        from collections import defaultdict

        files = defaultdict(list)

        for line in f:

            ext = line.strip().split(' ')
            files[ext[0]].append(ext[-1])

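        # A_EARTH is the semi-major (equatorial) radius; the polar radius is
        # derived from it using E_EARTH, which is treated here as the first
        # eccentricity: pRadius = eRadius * sqrt(1 - e**2).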
        eRadius = float(files['A_EARTH'][0])
        pRadius = eRadius * math.sqrt(1 - (float(files['E_EARTH'][0]) ** 2))

        return eRadius, pRadius

def reproject(record, semi_major, semi_minor, source_proj, dest_proj, **kwargs):
    """
    Thin wrapper around pyproj's transform() function to transform one or more
    three-dimensional points from one coordinate system to another. If converting
    between Cartesian body-centered body-fixed (BCBF) coordinates and
    Longitude/Latitude/Altitude coordinates, the values input for the semi-major
    and semi-minor axes determine whether latitudes are planetographic or
    planetocentric and determine the shape of the datum for altitudes.
    If semi_major == semi_minor, then latitudes are interpreted/created as
    planetocentric and altitudes are interpreted/created as referenced to a
    spherical datum. If semi_major != semi_minor, then latitudes are
    interpreted/created as planetographic and altitudes are interpreted/created
    as referenced to an ellipsoidal datum.

    Parameters
    ----------
    record : object
             Pandas series object

    semi_major : float
                 Radius from the center of the body to the equator

    semi_minor : float
                 Radius from the pole to the center of mass

    source_proj : str
                  Pyproj string that defines a projection space, e.g. 'geocent'

    dest_proj : str
                Pyproj string that defines a projection space, e.g. 'latlon'

    Returns
    -------
    : list
      Transformed coordinates as y, x, z
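
    Examples
    --------
    A minimal, illustrative sketch assuming a spherical body with a radius of
    1737400.0 m (the Moon); exact output depends on the installed pyproj
    version:

    >>> reproject([0.0, 0.0, 0.0], 1737400.0, 1737400.0,
    ...           'latlon', 'geocent')  # doctest: +SKIP
    (1737400.0, 0.0, 0.0)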

    """
    source_pyproj = pyproj.Proj(proj = source_proj, a = semi_major, b = semi_minor)
    dest_pyproj = pyproj.Proj(proj = dest_proj, a = semi_major, b = semi_minor)

    y, x, z = pyproj.transform(source_pyproj, dest_pyproj, record[0], record[1], record[2], **kwargs)
    return y, x, z

# TODO: Does isis cnet need a covariance matrix for sigmas? Even with a static matrix of 1,1,1,1
def compute_sigma_covariance_matrix(lat, lon, rad, latsigma, lonsigma, radsigma, semimajor_axis):

    """
    Given geospatial coordinates, desired accuracy sigmas, and an equitorial radius, compute a 2x3
    sigma covariange matrix.
    Parameters
    ----------
    lat : float
          A point's latitude in degrees

    lon : float
          A point's longitude in degrees

    rad : float
          The radius (z-value) of the point in meters

    latsigma : float
               The desired latitude accuracy in meters (Default 10.0)

    lonsigma : float
               The desired longitude accuracy in meters (Default 10.0)

    radsigma : float
               The desired radius accuracy in meters (Defualt: 15.0)

    semimajor_axis : float
                     The semi-major or equitorial radius in meters (Default: 1737400.0 - Moon)
    Returns
    -------
    rectcov : ndarray
              (2,3) covariance matrix
    """
    lat = math.radians(lat)
    lon = math.radians(lon)

    # SetSphericalSigmasDistance
    scaled_lat_sigma = latsigma / semimajor_axis

    # This is specific to each lon.
    scaled_lon_sigma = lonsigma * math.cos(lat) / semimajor_axis

    # SetSphericalSigmas
    cov = np.eye(3,3)
    cov[0,0] = math.radians(scaled_lat_sigma) ** 2
    cov[1,1] = math.radians(scaled_lon_sigma) ** 2
    cov[2,2] = radsigma ** 2

    # Approximate the Jacobian
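    # j holds the partial derivatives of the body-fixed coordinates
    # x = r*cos(lat)*cos(lon), y = r*cos(lat)*sin(lon), z = r*sin(lat)
    # with respect to (lat, lon, radius); cov is propagated to rectangular
    # space below as j * cov * j^T.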
    j = np.zeros((3,3))
    cosphi = math.cos(lat)
    sinphi = math.sin(lat)
    cos_lmbda = math.cos(lon)
    sin_lmbda = math.sin(lon)
    rcosphi = rad * cosphi
    rsinphi = rad * sinphi
    j[0,0] = -rsinphi * cos_lmbda
    j[0,1] = -rcosphi * sin_lmbda
    j[0,2] = cosphi * cos_lmbda
    j[1,0] = -rsinphi * sin_lmbda
    j[1,1] = rcosphi * cos_lmbda
    j[1,2] = cosphi * sin_lmbda
    j[2,0] = rcosphi
    j[2,1] = 0.
    j[2,2] = sinphi
    mat = j.dot(cov)
    mat = mat.dot(j.T)
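    # Pack the six unique (upper-triangular) elements of the symmetric 3x3
    # matrix into a (2, 3) array.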
    rectcov = np.zeros((2,3))
    rectcov[0,0] = mat[0,0]
    rectcov[0,1] = mat[0,1]
    rectcov[0,2] = mat[0,2]
    rectcov[1,0] = mat[1,1]
    rectcov[1,1] = mat[1,2]
    rectcov[1,2] = mat[2,2]

    return rectcov

def compute_cov_matrix(record, semimajor_axis):
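    """
    Computes the apriori sigma covariance matrix for a gpf record and
    flattens it into a list
    """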
    cov_matrix = compute_sigma_covariance_matrix(record['lat_Y_North'], record['long_X_East'], record['ht'],
                                                 record['sig0'], record['sig1'], record['sig2'], semimajor_axis)
    return cov_matrix.ravel().tolist()

def stat_toggle(record):
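    """
    Returns True when the socet stat field is 0 and False otherwise
    """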
    if record['stat'] == 0:
        return True
    else:
        return False

def known(record):
    """
    Converts the known field from a socet dataframe into the
    isis point_type column

    Parameters
    ----------
    record : object
             Pandas series object

    Returns
    -------
    : str
      String representation of a known field
    """

    lookup = {0: 'Free',
              1: 'Constrained',
              2: 'Constrained',
              3: 'Constrained'}
    return lookup[record['known']]

def apply_socet_transformations(atf_dict, df):
    """
    Takes an atf dictionary and a socet dataframe and applies the necessary
    transformations to convert that dataframe into an ISIS-compatible
    dataframe

    Parameters
    ----------
    atf_dict : dict
               Dictionary containing information from an atf file

    df : object
         Pandas dataframe object

    """
    prj_file = os.path.join(atf_dict['PATH'], atf_dict['PROJECT'])

    eRadius, pRadius = get_axis(prj_file)

    # Convert longitude and latitude from radians to degrees
    df['long_X_East'] = df['long_X_East'].apply(np.degrees)
    df['lat_Y_North'] = df['lat_Y_North'].apply(np.degrees)

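    # Reproject every ground point from lon/lat/height (degrees, meters) to
    # body-fixed XYZ in one call; the extra nesting matches the positional
    # indexing that reproject expects.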
    lla = np.array([[df['long_X_East']], [df['lat_Y_North']], [df['ht']]])

    ecef = reproject(lla, semi_major = eRadius, semi_minor = pRadius,
                              source_proj = 'latlon', dest_proj = 'geocent')

    df['s.'], df['l.'], df['image_index'] = (zip(*df.apply(line_sample_size, path = atf_dict['PATH'], axis=1)))
    df['known'] = df.apply(known, axis=1)
    df['long_X_East'] = ecef[0][0]
    df['lat_Y_North'] = ecef[1][0]
    df['ht'] = ecef[2][0]
    df['aprioriCovar'] = df.apply(compute_cov_matrix, semimajor_axis = eRadius, axis=1)
    df['stat'] = df.apply(stat_toggle, axis=1)
    
def main(args):
    # Setup the at_file, path to cubes, and control network out path
    at_file = args.at_file

    with open(args.cub_list, 'r') as f:
        lines = f.readlines()
        cub_list = [cub.replace('\n', '') for cub in lines]

    cnet_out = os.path.split(os.path.splitext(at_file)[0])[1]

    if args.outpath:
        outpath = args.outpath
    else:
        outpath = os.path.split(at_file)[0]

    # Read in and setup the atf dict of information
    atf_dict = read_atf(at_file)

    # Get the gpf and ipf files using atf dict
    gpf_file = os.path.join(atf_dict['PATH'], atf_dict['GP_FILE'])
    ipf_list = [os.path.join(atf_dict['PATH'], i) for i in atf_dict['IMAGE_IPF']]

    # Read in the gpf file and ipf file(s) into separate dataframes
    gpf_df = read_gpf(gpf_file)
    ipf_df = read_ipf(ipf_list)

    # Check for differences between point ids using each dataframe's
    # point ids as a reference
    gpf_pt_idx = pd.Index(pd.unique(gpf_df['point_id']))
    ipf_pt_idx = pd.Index(pd.unique(ipf_df['pt_id']))

    point_diff = ipf_pt_idx.difference(gpf_pt_idx)

    if len(point_diff) != 0:
        warnings.warn("The following points found in the ipf files are missing from the gpf file:\n\n{}\n\n"
                      "Continuing, but these points will be missing from the control network".format(list(point_diff)))

    # Merge the two dataframes on their point id columns
    socet_df = ipf_df.merge(gpf_df, left_on='pt_id', right_on='point_id')

    # Apply the transformations
    apply_socet_transformations(atf_dict, socet_df)

    # Define column remap for socet dataframe
    column_map = {'pt_id': 'id', 'l.': 'y', 's.': 'x',
                  'res_l': 'lineResidual', 'res_s': 'sampleResidual', 'known': 'Type',
                  'lat_Y_North': 'aprioriY', 'long_X_East': 'aprioriX', 'ht': 'aprioriZ',
                  'sig0': 'aprioriLatitudeSigma', 'sig1': 'aprioriLongitudeSigma', 'sig2': 'aprioriRadiusSigma',
                  'sig_l': 'linesigma', 'sig_s': 'samplesigma'}

    # Rename the columns using the column remap above
    socet_df.rename(columns = column_map, inplace=True)

    # Build a serial dict assuming the cubes will be named as the IPFs are
    serial_dict = {split_all_ext(os.path.split(i)[-1]): sn.generate_serial_number(i) for i in cub_list}

    # Create the control network
    cn.to_isis(os.path.join(outpath, cnet_out + '.net'), socet_df, serial_dict, targetname = args.target_name)

if __name__ == '__main__':
    main(parse_args())