import datetime
import json
import psycopg2
import psycopg2.extras
import sys

from node import Node


class DbConnector(object):
    """Data-access wrapper around a psycopg2 connection to the VOSpace database.

    Rows are fetched through a RealDictCursor, so they are read by column name.
    Every public query/update method is a no-op returning None while no
    connection is open (before connect() or after disconnect()).
    Each statement runs in its own transaction: it is committed on success and
    rolled back on error, so one failing query cannot leave the connection
    stuck in an aborted transaction.
    """

    def __init__(self, user, password, host, port, dbname):
        self.user = user
        self.password = password
        self.host = host
        self.port = port
        self.dbname = dbname
        # Start disconnected so the `if self.conn:` guards and disconnect()
        # work before connect() instead of raising AttributeError.
        self.conn = None
        self.cursor = None

    def connect(self):
        """Opens the connection and a RealDictCursor; exits the process on failure."""
        try:
            self.conn = psycopg2.connect(user=self.user,
                                         password=self.password,
                                         host=self.host,
                                         port=self.port,
                                         database=self.dbname)
        except Exception as error:
            # psycopg2.Error derives from Exception, so it is covered here too.
            sys.exit(f"Error while connecting to PostgreSQL: {error}")
        self.cursor = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)

    def disconnect(self):
        """Closes the cursor and the connection, if open. Safe to call repeatedly."""
        if self.conn:
            if self.cursor is not None:
                self.cursor.close()
            self.conn.close()
            # Reset so the `if self.conn:` guards treat the connector as closed.
            self.conn = None
            self.cursor = None

    # Private helpers

    def _fetchall(self, query, params=None):
        """Runs a read query in its own transaction and returns all rows.

        Leaving `with self.conn` commits on success and rolls back on error.
        """
        with self.conn:
            self.cursor.execute(query, params)
            return self.cursor.fetchall()

    def _execute(self, query, params=None):
        """Runs a write statement and commits it; rolls back if it fails."""
        with self.conn:
            self.cursor.execute(query, params)

    def _setNodeFlag(self, column, nodeVOSPath, value):
        """Sets `column` to `value` on every node whose `path` is `<@` the path
        of the node stored at `nodeVOSPath`.

        With ltree paths, `<@` matches the node itself and its descendants.
        `column` is interpolated into the SQL text, so it is restricted to a
        fixed whitelist; the values are passed as query parameters.

        Raises:
            ValueError: if `column` is not one of the supported flag columns.
        """
        if column not in ("async_trans", "sticky", "busy_state"):
            raise ValueError(f"Unsupported node flag column: {column!r}")
        self._execute(f"""
            UPDATE node SET {column} = %s
            WHERE path <@
            (SELECT path
             FROM node_path p
             JOIN node n ON p.node_id = n.node_id
             WHERE p.vos_path = %s);
            """,
            (value, nodeVOSPath,))

    @staticmethod
    def _toCamelCase(name):
        """Converts a snake_case column name to camelCase, e.g. 'job_id' -> 'jobId'."""
        head, *tail = name.split('_')
        return head + ''.join(token.title() for token in tail)

    @staticmethod
    def _debugLog(text):
        """Appends debug text to 'db_connector_log.txt' in the working directory."""
        with open("db_connector_log.txt", "a", encoding="utf-8") as out:
            out.write(text)

    # Getters

    ### Node

    def nodeExists(self, node):
        """Checks if a VOSpace node already exists.

        The VOSpace path looked up is `node.parentPath + '/' + node.name`.
        Returns True or False (None when not connected).
        """
        if self.conn:
            nodeVOSPath = node.parentPath + '/' + node.name
            rows = self._fetchall("SELECT * FROM node_vos_path WHERE vos_path = %s;", (nodeVOSPath,))
            return bool(rows)

    def getOSPath(self, vospacePath):
        """Resolves a VOSpace path to its physical location.

        Returns:
            [fullPath, storageType, userName, osPath], where fullPath is
            base_path/user_name[/tstamp_wrapper_dir]os_path.

        Raises:
            IndexError: if no node matches `vospacePath`.
        """
        if self.conn:
            rows = self._fetchall("""
                SELECT storage_type, base_path, user_name, tstamp_wrapper_dir, os_path
                FROM node_path p
                JOIN node n ON p.node_id = n.node_id
                JOIN location l ON n.location_id = l.location_id
                JOIN storage s ON s.storage_id = l.storage_src_id
                JOIN users u ON u.rap_id = n.owner_id
                WHERE p.vos_path = %s;
                """,
                (vospacePath,))
            row = rows[0]
            userName = row["user_name"]
            osPath = row["os_path"]
            tstampWrapperDir = row["tstamp_wrapper_dir"]
            if tstampWrapperDir is None:
                fullPath = row["base_path"] + "/" + userName + osPath
            else:
                fullPath = row["base_path"] + "/" + userName + "/" + tstampWrapperDir + osPath
            return [fullPath, row["storage_type"], userName, osPath]

    def getVOSpacePathList(self, vospacePath):
        """Returns the VOSpace paths listed (via list_of_files) by the node at `vospacePath`."""
        if self.conn:
            rows = self._fetchall("""
                SELECT op.vos_path
                FROM node_vos_path vp
                JOIN list_of_files l ON l.list_node_id = vp.node_id
                JOIN node_path op ON op.node_id = l.node_id
                WHERE vp.vos_path = %s;
                """,
                (vospacePath,))
            return [row["vos_path"] for row in rows]

    ### Job

    def getJob(self, jobId):
        """Returns the job row as a dict with camelCase keys.

        datetime values are converted to ISO-8601 strings. When no job matches,
        returns {"error": "JOB_NOT_FOUND"}. The found job is also appended to
        the debug log file.
        """
        if self.conn:
            rows = self._fetchall("SELECT * FROM job WHERE job_id = %s;", (jobId,))
            if not rows:
                return {"error": "JOB_NOT_FOUND"}
            job = {}
            for column, value in rows[0].items():
                if isinstance(value, datetime.datetime):
                    value = value.isoformat()
                job[self._toCamelCase(column)] = value
            self._debugLog(f"job: {job}\n\n")
            return job

    ### Users

    def userExists(self, username):
        """Checks if a user already exists. Returns True or False."""
        if self.conn:
            rows = self._fetchall("SELECT * FROM users WHERE user_name = %s;", (username,))
            return bool(rows)

    def getRapId(self, username):
        """Returns the RAP id for a given user name (IndexError if unknown)."""
        if self.conn:
            return self._fetchall("SELECT rap_id FROM users WHERE user_name = %s;", (username,))[0]["rap_id"]

    def getUserName(self, rapId):
        """Returns the user name for a given RAP id (IndexError if unknown)."""
        if self.conn:
            return self._fetchall("SELECT user_name FROM users WHERE rap_id = %s;", (rapId,))[0]["user_name"]

    ### Storage

    def storageBasePathIsValid(self, path):
        """Checks if a physical path contains a registered storage base path.

        Returns the matching base path, or False when none matches.
        """
        if self.conn:
            # NOTE(review): position(...) > 0 matches base_path anywhere inside
            # `path`, not only as a prefix; with several matches the first row
            # returned wins. Confirm whether a prefix match (= 1) was intended.
            rows = self._fetchall("""
                SELECT base_path
                FROM storage
                WHERE position(base_path in cast(%s as varchar)) > 0;
                """,
                (path,))
            if rows:
                return rows[0]["base_path"]
            return False

    def getStorageBasePath(self, storageId):
        """Returns the storage base path for a given storage id (IndexError if unknown)."""
        if self.conn:
            return self._fetchall("SELECT base_path FROM storage WHERE storage_id = %s;", (storageId,))[0]["base_path"]

    def getStorageList(self):
        """Returns all storage rows except those of type 'local'."""
        if self.conn:
            return self._fetchall("SELECT * FROM storage WHERE storage_type <> 'local';")

    def getStorageListByType(self, storageType):
        """Returns all storage rows of the given storage type."""
        if self.conn:
            return self._fetchall("SELECT * FROM storage WHERE storage_type = %s;", (storageType,))

    def getStorageId(self, basePath):
        """Returns the storage id for a given base path, or False if none is registered."""
        if self.conn:
            rows = self._fetchall("SELECT storage_id FROM storage WHERE base_path = %s;", (basePath,))
            if rows:
                return rows[0]["storage_id"]
            return False

    ### Location

    def getLocationId(self, destStorageId):
        """Returns the id of the location whose storage_src_id equals the given id.

        Raises IndexError if no location matches.
        """
        if self.conn:
            # NOTE(review): the parameter is named "dest" but the query filters
            # on storage_src_id; confirm which column callers expect.
            return self._fetchall("SELECT location_id FROM location WHERE storage_src_id = %s;", (destStorageId,))[0]["location_id"]

    # Setters

    ### Job

    def insertJob(self, jobObj):
        """Inserts a job, or updates every column of an existing job with the same job_id.

        jobInfo and results are stored JSON-encoded.
        """
        if self.conn:
            self._execute("""
                INSERT INTO job(job_id,
                                owner_id,
                                job_type,
                                phase,
                                start_time,
                                end_time,
                                job_info,
                                results)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                ON CONFLICT (job_id)
                DO UPDATE SET
                (owner_id,
                 job_type, phase,
                 start_time,
                 end_time,
                 job_info,
                 results)
                = (EXCLUDED.owner_id,
                   EXCLUDED.job_type,
                   EXCLUDED.phase,
                   EXCLUDED.start_time,
                   EXCLUDED.end_time,
                   EXCLUDED.job_info,
                   EXCLUDED.results);
                """,
                (jobObj.jobId,
                 jobObj.ownerId,
                 jobObj.type,
                 jobObj.phase,
                 jobObj.startTime,
                 jobObj.endTime,
                 json.dumps(jobObj.jobInfo),
                 json.dumps(jobObj.results),))

    def setPhase(self, jobId, phase):
        """Sets the job 'phase' column."""
        if self.conn:
            self._execute("""
                UPDATE job SET phase = %s
                WHERE job_id = %s;
                """,
                (phase, jobId,))

    def setResults(self, jobId, results):
        """Sets the job 'results' column (stored JSON-encoded)."""
        if self.conn:
            self._execute("""
                UPDATE job SET results = %s
                WHERE job_id = %s;
                """,
                (json.dumps(results),
                 jobId,))

    ### Node

    def insertNode(self, node):
        """Inserts a VOSpace node under the parent found at `node.parentPath`.

        The parent's ltree `path` becomes the new row's parent_path. Its
        relative form drops the first ltree label and becomes
        parent_relative_path. Intermediate values are appended to the debug
        log file.

        Raises:
            IndexError: if no node exists at `node.parentPath`. The transaction
            is rolled back in that case.
        """
        if self.conn:
            # One transaction for lookup + insert: commit on success, rollback on error.
            with self.conn:
                self._debugLog(f"parentOSPath: {node.parentPath}\n"
                               f"name: {node.name}\n")
                self.cursor.execute("""
                    SELECT path FROM node n
                    JOIN node_vos_path o ON n.node_id = o.node_id
                    WHERE vos_path = %s;
                    """,
                    (node.parentPath,))
                result = self.cursor.fetchall()
                self._debugLog("".join(f"queryResult: {row}\n" for row in result))
                parentLtreePath = result[0]["path"]
                parentLtreeRelativePath = ""
                if "." in parentLtreePath:
                    # Relative path = parent path without its first label.
                    parentLtreeRelativePath = ".".join(parentLtreePath.strip(".").split('.')[1:])
                self._debugLog(f"parentLtreeRelativePath: {parentLtreeRelativePath}\n"
                               f"parentLtreePath: {parentLtreePath}\n"
                               f"parentPath: {node.parentPath}\n\n")
                self.cursor.execute("""
                    INSERT INTO node(parent_path,
                                     parent_relative_path,
                                     name,
                                     tstamp_wrapper_dir,
                                     type,
                                     location_id,
                                     busy_state,
                                     owner_id,
                                     creator_id,
                                     content_length,
                                     content_md5)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s);
                    """,
                    (parentLtreePath,
                     parentLtreeRelativePath,
                     node.name,
                     node.wrapperDir,
                     node.type,
                     node.locationId,
                     node.busyState,
                     node.ownerID,
                     node.creatorID,
                     node.contentLength,
                     node.contentMD5,))

    def deleteTmpDataNode(self, vospacePath):
        """Deletes a temporary VOSpace data node and its list_of_files entries."""
        if self.conn:
            # NOTE(review): the node row is deleted only through the ids RETURNED
            # by the list_of_files delete; a node with no list_of_files rows is
            # therefore left in place. Confirm this is intended.
            self._execute("""
                WITH deleted AS (
                DELETE FROM list_of_files
                WHERE list_node_id =
                (SELECT node_id FROM node_vos_path
                 WHERE vos_path = %s)
                 RETURNING list_node_id
                ) DELETE FROM node
                  WHERE node_id =
                  (SELECT DISTINCT(list_node_id)
                   FROM deleted);
                """,
                (vospacePath,))

    def setAsyncTrans(self, nodeVOSPath, value):
        """Sets the 'async_trans' flag for a VOSpace node and the nodes under it."""
        if self.conn:
            self._setNodeFlag("async_trans", nodeVOSPath, value)

    def setSticky(self, nodeVOSPath, value):
        """Sets the 'sticky' flag for a VOSpace node and the nodes under it."""
        if self.conn:
            self._setNodeFlag("sticky", nodeVOSPath, value)

    def setBusyState(self, nodeVOSPath, value):
        """Sets the 'busy_state' flag for a VOSpace node and the nodes under it."""
        if self.conn:
            self._setNodeFlag("busy_state", nodeVOSPath, value)

    ### Storage

    def insertStorage(self, storageType, basePath, hostname):
        """Registers a storage point and the location that uses it.

        'cold' and 'hot' storages get an 'async' location whose destination is
        the local '/home' storage on 'localhost'. Any other type gets a
        'portal' location pointing to itself.

        Returns:
            True if the storage was inserted, False if `basePath` is already
            registered.

        Raises:
            IndexError: if a 'cold'/'hot' storage is inserted but the local
            destination storage is missing. Both inserts are rolled back.
        """
        if self.conn:
            # getStorageId returns False (not a falsy id) when nothing matches.
            if self.getStorageId(basePath) is not False:
                return False
            # Storage + location are inserted atomically: commit or rollback together.
            with self.conn:
                self.cursor.execute("""
                    INSERT INTO storage(storage_type,
                                        base_path,
                                        hostname)
                    VALUES (%s, %s, %s)
                    RETURNING storage_id;
                    """,
                    (storageType,
                     basePath,
                     hostname,))
                storageSrcId = self.cursor.fetchall()[0]["storage_id"]

                if storageType in ("cold", "hot"):
                    self.cursor.execute("""
                        SELECT storage_id
                        FROM storage
                        WHERE storage_type = 'local'
                        AND base_path = '/home'
                        AND hostname = 'localhost';
                        """)
                    storageDestId = self.cursor.fetchall()[0]["storage_id"]
                    locationType = "async"
                else:
                    storageDestId = storageSrcId
                    locationType = "portal"

                self.cursor.execute("""
                    INSERT INTO location(location_type,
                                         storage_src_id,
                                         storage_dest_id)
                    VALUES (%s, %s, %s);
                    """,
                    (locationType,
                     storageSrcId,
                     storageDestId,))
            return True
                
