#!/usr/bin/env python
#
# This file is part of vospace-transfer-service
# Copyright (C) 2021 Istituto Nazionale di Astrofisica
# SPDX-License-Identifier: GPL-3.0-or-later
#
#
# This class is responsible for retrieving data from a generic storage point.
#
# The operations performed are briefly summarized below:
# * obtain the storage type
# * create a list of files to be retrieved (list of dictionaries)
# * split the list in blocks of a fixed size
# * loop on each block and retrieve data
#   - if the storage type is 'cold' (tape) perform a recall operation
#     before the copy and a migrate operation after the copy
#   - after each block is retrieved, check whether all data associated
#     with each VOSpace node has been copied
#   - recursively update the 'async_trans' flag
# * cleanup
#
#


import datetime
import json
import os
import logging
import subprocess
import sys

from checksum import Checksum
from config import Config
from db_connector import DbConnector
from mailer import Mailer
from redis_log_handler import RedisLogHandler
from system_utils import SystemUtils
from tape_client import TapeClient
from task_executor import TaskExecutor


class RetrieveExecutor(TaskExecutor):
    """
    Task executor that serves data retrieval jobs.

    run() takes jobs from the 'read_ready' queue, copies (cold storage) or
    symlinks (any other storage type) the files of every VOSpace node of
    the job into the user's retrieve path, notifies the administrator and
    the job owner by e-mail and finally moves the job to the
    'read_terminated' queue.
    """

    def __init__(self):
        self.type = "retrieve_executor"
        self.systemUtils = SystemUtils()
        config = Config("/etc/vos_ts/vos_ts.conf")
        params = config.loadSection("transfer_node")
        self.storageRetrievePath = params["retrieve_path"]
        params = config.loadSection("transfer")
        self.maxBlockSize = self.systemUtils.convertSizeToBytes(params["block_size"])
        params = config.loadSection("scheduling")
        self.maxTerminatedJobs = params.getint("max_terminated_jobs")
        params = config.loadSection("mail")
        self.adminEmail = params["admin_email"]
        params = config.loadSection("logging")
        self.logger = logging.getLogger(__name__)
        logFormatter = logging.Formatter(params["log_format"])
        # Resolve the level name (e.g. "INFO") as an attribute of the logging
        # module instead of eval()-ing text read from the configuration file
        self.logger.setLevel(getattr(logging, params["log_level"]))
        redisLogHandler = RedisLogHandler()
        redisLogHandler.setFormatter(logFormatter)
        self.logger.addHandler(redisLogHandler)
        params = config.loadSection("file_catalog")
        self.dbConn = DbConnector(params["user"],
                                  params["password"],
                                  params["host"],
                                  params.getint("port"),
                                  params["db"],
                                  1,
                                  1,
                                  self.logger)
        params = config.loadSection("spectrum_archive")
        self.tapeClient = TapeClient(params["host"],
                                     params.getint("port"),
                                     params["user"],
                                     params["pkey_file_path"],
                                     self.logger)
        # Per-job state, reset by cleanup()
        self.tapePool = None
        self.storageType = None
        self.jobObj = None
        self.jobId = None
        self.nodeList = []
        self.fileList = []
        self.destPathList = []
        self.numBlocks = 0
        self.procBlocks = 0
        self.totalSize = 0
        super(RetrieveExecutor, self).__init__()

    @staticmethod
    def _stripPrefix(path, prefix):
        """
        Returns 'path' without its leading 'prefix', or 'path' unchanged if
        it does not start with 'prefix'. Unlike str.replace(), occurrences
        of 'prefix' elsewhere in 'path' are left untouched.
        """
        return path[len(prefix):] if path.startswith(prefix) else path

    def _addFile(self, baseSrcPath, fullPath, username, fileSize, vospacePath):
        """
        Appends one file entry to 'self.fileList' and adds its size to
        'self.totalSize'.
        """
        self.fileList.append({
                                "baseSrcPath": baseSrcPath,
                                "fullPath": fullPath,
                                "username": username,
                                "fileSize": fileSize,
                                "vospaceRootParent": vospacePath
                             })
        self.totalSize += fileSize

    def buildFileList(self):
        """
        Generates the list of all files to retrieve.

        The job phase is first set to EXECUTING in the database. The storage
        type and tape pool are read from the first node of 'self.nodeList'
        only and are applied to the whole job. Every node is then expanded
        into one entry per file: directories are walked recursively, while
        symlinks and files rejected by Checksum.fileIsValid() are skipped.
        Each entry is a dict with the keys 'baseSrcPath', 'fullPath',
        'username', 'fileSize' and 'vospaceRootParent' (the node it comes
        from). Entries are appended to 'self.fileList' and their sizes are
        summed into 'self.totalSize'.

        This method returns 'True' on success, 'False' on failure.
        """
        try:
            try:
                self.jobObj.setPhase("EXECUTING")
                self.jobObj.setStartTime(datetime.datetime.now().isoformat())
                self.dbConn.insertJob(self.jobObj)
            except Exception:
                self.logger.exception("FATAL: unable to update the file catalog.")
                return False
            else:
                self.logger.info("Job phase updated to EXECUTING.")
            self.logger.info("Building the list of the files to be retrieved...")

            # Obtain the storage type (an empty node list ends up here too,
            # via IndexError)
            try:
                firstNodeInfo = self.dbConn.getOSPath(self.nodeList[0])
                self.storageType = firstNodeInfo["storageType"]
                self.tapePool = firstNodeInfo["tapePool"]
            except Exception:
                self.logger.exception("FATAL: unable to obtain the storage type.")
                return False

            md5calc = Checksum()
            for vospacePath in self.nodeList:
                try:
                    nodeInfo = self.dbConn.getOSPath(vospacePath)
                except Exception:
                    self.logger.exception(f"FATAL: unable to obtain the OS path for the VOSpace path '{vospacePath}'.")
                    return False
                baseSrcPath = nodeInfo["baseSrcPath"]
                srcPath = nodeInfo["fullPath"]
                username = nodeInfo["username"]
                if os.path.isdir(srcPath) and not os.path.islink(srcPath):
                    for root, _dirs, files in os.walk(srcPath, topdown = False):
                        for f in files:
                            fullPath = os.path.join(root, f)
                            if md5calc.fileIsValid(fullPath) and not os.path.islink(fullPath):
                                self._addFile(baseSrcPath, fullPath, username,
                                              os.path.getsize(fullPath), vospacePath)
                elif md5calc.fileIsValid(srcPath) and not os.path.islink(srcPath):
                    # For a single-file node the size stored in the catalog is used
                    self._addFile(baseSrcPath, srcPath, username,
                                  nodeInfo["contentLength"], vospacePath)
            self.logger.info(f"Total size of files to retrieve: {self.totalSize} B")
        except Exception:
            self.logger.exception("FATAL: something went wrong while building the list of the files to be retrieved.")
            return False
        else:
            return True

    def buildBlocks(self):
        """
        Algorithm to split data in blocks of a well known size.

        Files are assigned, in list order, a 'blockIdx' key. A file is added
        to the current block while the block stays within
        'self.maxBlockSize'; otherwise a new block is opened. A file larger
        than 'self.maxBlockSize' opens a new block (unless the current one
        is still empty) and marks it as full. 'self.numBlocks' is set and
        stored in the database when the file list is not empty.

        This method returns 'True' on success, 'False' on failure.
        """
        try:
            self.logger.info("Building the blocks data structure... ")
            blockIdx = 0
            blockSize = 0
            for fileInfo in self.fileList:
                fileSize = fileInfo["fileSize"]
                if fileSize > self.maxBlockSize:
                    # Oversized file: reuse the current block only if it is
                    # empty, then mark it as full so that the next non-empty
                    # file starts a new block
                    if blockSize > 0:
                        blockIdx += 1
                    blockSize = self.maxBlockSize
                elif blockSize + fileSize <= self.maxBlockSize:
                    # The file fits in the current block
                    blockSize += fileSize
                else:
                    # "Close" the current block and start a new one with this file
                    blockIdx += 1
                    blockSize = fileSize
                fileInfo["blockIdx"] = blockIdx
            if self.fileList:
                self.numBlocks = blockIdx + 1
                try:
                    self.dbConn.setTotalBlocks(self.jobId, self.numBlocks)
                except Exception:
                    self.logger.exception("FATAL: unable to set the total number of blocks in the database.")
                    return False
        except Exception:
            self.logger.exception("FATAL: something went wrong while building the blocks data structure.")
            return False
        else:
            return True

    def retrieveCompleted(self, vospacePath):
        """
        Returns 'True' if all data associated to 'vospacePath'
        has been copied, otherwise it returns 'False'.

        Data is still pending while 'self.fileList' holds a file whose
        'vospaceRootParent' is 'vospacePath' itself or a node below it.
        Matching is done on whole path components, so e.g. '/a/b' is not
        held back by pending files of '/a/bc'.
        """
        prefix = vospacePath.rstrip("/") + "/"
        return not any(f["vospaceRootParent"] == vospacePath
                       or f["vospaceRootParent"].startswith(prefix)
                       for f in self.fileList)

    def _runTapeOperation(self, operation, *args):
        """
        Connects the tape client, runs 'operation(*args)' and always
        disconnects, even when the operation raises.
        """
        self.tapeClient.connect()
        try:
            operation(*args)
        finally:
            self.tapeClient.disconnect()

    def _copyFile(self, fileInfo):
        """
        Copies with rsync ('cold' storage) or symlinks (any other storage
        type) a single file into the user's retrieve path, preserving its
        directory relative to 'baseSrcPath'.
        Returns 'True' on success, 'False' on failure.
        """
        srcPath = fileInfo["fullPath"]
        username = fileInfo["username"]
        osRelParentPath = self._stripPrefix(os.path.dirname(srcPath), fileInfo["baseSrcPath"])
        if osRelParentPath != "/":
            osRelParentPath += "/"
        destDirPath = self.storageRetrievePath.replace("{username}", username) + osRelParentPath
        os.makedirs(destDirPath, exist_ok = True)
        if self.storageType == "cold":
            sp = subprocess.run(["rsync", "-av", "--no-links", srcPath, destDirPath], capture_output = True)
            if sp.returncode or sp.stderr:
                self.logger.error(f"FATAL: error during the copy process, returnCode = {sp.returncode}, stderr: {sp.stderr}")
                return False
        else:
            destPath = destDirPath + os.path.basename(srcPath)
            try:
                # An identical link that already exists (e.g. the same file
                # listed under two nested nodes) is kept instead of failing
                # with FileExistsError
                alreadyLinked = os.path.islink(destPath) and os.readlink(destPath) == srcPath
                if not os.path.islink(srcPath) and not alreadyLinked:
                    os.symlink(srcPath, destPath)
            except Exception:
                self.logger.exception(f"FATAL: error while creating symlink for target '{srcPath}'")
                return False
        return True

    def _updateAsyncTrans(self, flaggedNodes):
        """
        Clears the 'async_trans' flag of every node in 'self.nodeList' whose
        data has been fully retrieved and which is not yet in 'flaggedNodes'
        (a set updated in place, so each node is updated only once).
        Returns 'True' on success, 'False' if the catalog update fails.
        """
        for vospacePath in self.nodeList:
            if vospacePath not in flaggedNodes and self.retrieveCompleted(vospacePath):
                try:
                    self.dbConn.setAsyncTrans(vospacePath, False)
                except Exception:
                    self.logger.exception("FATAL: unable to update the file catalog.")
                    return False
                flaggedNodes.add(vospacePath)
        return True

    def retrieveData(self):
        """
        Retrieves data from a generic storage point (hot or cold).

        Blocks are processed in order. For each block:
          * 'cold' storage: the block's files are recalled from tape first;
          * every file is copied or symlinked (see _copyFile());
          * the block's files are dropped from 'self.fileList' and the
            'async_trans' flag is cleared for every completed node;
          * 'cold' storage: the block's files are migrated back to tape;
          * the processed-blocks counter is updated in the database.

        This method returns 'True' on success, 'False' on failure.
        """
        try:
            self.logger.info("Starting data retrieval...")
            # Nodes whose 'async_trans' flag has already been cleared
            flaggedNodes = set()
            for blockIdx in range(self.numBlocks):
                blockFileList = [ f for f in self.fileList if f["blockIdx"] == blockIdx ]
                blockPaths = [ f["fullPath"] for f in blockFileList ]

                # Recall all files from tape library to tape frontend
                if self.storageType == "cold":
                    self._runTapeOperation(self.tapeClient.recall, blockPaths, self.jobId)

                for fileInfo in blockFileList:
                    if not self._copyFile(fileInfo):
                        return False

                # Remove the files of this block from the file list in a
                # single pass (instead of one O(n) list.remove() per file)
                self.fileList[:] = [ f for f in self.fileList if f["blockIdx"] != blockIdx ]

                if not self._updateAsyncTrans(flaggedNodes):
                    return False

                # Empty the tape library frontend
                if self.storageType == "cold":
                    self._runTapeOperation(self.tapeClient.migrate, blockPaths, self.tapePool, self.jobId)

                self.procBlocks += 1
                self.dbConn.updateProcessedBlocks(self.jobId, self.procBlocks)

            # With nothing to copy the loop never runs: still clear the flag
            # of the (empty) nodes so that they are not left pending
            if self.numBlocks == 0 and not self._updateAsyncTrans(flaggedNodes):
                return False
        except Exception:
            self.logger.exception("FATAL: something went wrong while retrieving the data.")
            return False
        else:
            return True

    def execute(self):
        """
        Runs the execution phase: builds the file list, splits it in blocks
        and retrieves the data. Returns 'True' on success, 'False' on failure.
        """
        self.logger.info("++++++++++ Start of execution phase ++++++++++")
        # 'and' (not '&') so that a failed step stops the pipeline instead of
        # letting the following steps run on incomplete data
        success = self.buildFileList() and self.buildBlocks() and self.retrieveData()
        if success:
            self.logger.info("++++++++++ End of execution phase ++++++++++")
            return True
        self.logger.error("FATAL: something went wrong during the execution phase.")
        return False

    def update(self, status):
        """
        Updates the job status and sends an email to the user.

        'status' is "OK" for a successful execution; any other value marks
        the job as failed (phase ERROR). Before the job is stored, the
        physical destination path of each VOSpace node is saved in
        'self.jobObj.jobInfo["destPathList"]'. The e-mail goes to the
        administrator and, if different, to the job owner. Errors are
        logged and not propagated.
        """
        try:
            results = [{"target": self.jobObj.jobInfo["transfer"]["target"]}]

            m = Mailer(self.logger)
            m.addRecipient(self.adminEmail)
            userEmail = self.dbConn.getUserEmail(self.jobObj.ownerId)
            if userEmail != self.adminEmail:
                m.addRecipient(userEmail)

            self.jobObj.setResults(results)

            # Add a list of physical destination paths for each VOSpace node in the node list
            self.logger.info("Generating physical destination paths for VOSpace nodes...")
            for vospacePath in self.nodeList:
                nodeInfo = self.dbConn.getOSPath(vospacePath)
                baseSrcPath = nodeInfo["baseSrcPath"]
                srcPath = nodeInfo["fullPath"]
                baseDestPath = self.storageRetrievePath.replace("{username}", nodeInfo["username"])
                # Swap only the leading 'baseSrcPath' for the destination
                # root; str.replace() would also rewrite later occurrences
                if srcPath.startswith(baseSrcPath):
                    destPath = baseDestPath + srcPath[len(baseSrcPath):]
                else:
                    destPath = srcPath
                self.destPathList.append(destPath)
            self.jobObj.jobInfo["destPathList"] = self.destPathList.copy()

            if status == "OK":
                self.jobObj.setPhase("COMPLETED")
                self.jobObj.setEndTime(datetime.datetime.now().isoformat())
                self.dbConn.insertJob(self.jobObj)
                self.logger.info("Job phase updated to COMPLETED.")

                msg = f"""
        ########## VOSpace data retrieval procedure summary ##########

        Dear user,
        your job has been COMPLETED.

        Job ID: {self.jobId}
        Job type: {self.jobObj.type}
        Owner ID: {self.jobObj.ownerId}

        Your files are available and can be downloaded.

        """
                m.setMessage("VOSpace data retrieve notification: COMPLETED", msg)
            else:
                self.jobObj.setPhase("ERROR")
                self.jobObj.setErrorType("fatal")
                self.jobObj.setErrorMessage("FATAL: something went wrong during the execution phase.")
                self.jobObj.setEndTime(datetime.datetime.now().isoformat())
                self.dbConn.insertJob(self.jobObj)
                self.logger.info("Job phase updated to ERROR.")

                msg = f"""
        ########## VOSpace data retrieval procedure summary ##########

        Dear user,
        your job has FAILED.

        Job ID: {self.jobId}
        Job type: {self.jobObj.type}
        Owner ID: {self.jobObj.ownerId}
        """
                info = f"""
        ERROR:
        the job was terminated due to an error that occurred
        while retrieveing the data from the storage point.

        This issue will be automatically reported to the administrator.

        """
                msg += info
                m.setMessage("VOSpace data retrieve notification: ERROR", msg)
            # Send e-mail notification
            m.send()
        except Exception:
            self.logger.exception(f"FATAL: unable to update the database, job ID: {self.jobId}")

    def cleanup(self):
        """
        Cleanup method: resets the per-job state before the next job.
        'self.jobObj' and 'self.jobId' are kept because run() still uses
        the job object after cleanup.
        """
        self.logger.info("Cleanup...")
        self.fileList.clear()
        self.nodeList.clear()
        self.destPathList.clear()
        self.storageType = None
        self.tapePool = None
        self.numBlocks = 0
        self.procBlocks = 0
        self.totalSize = 0

    def run(self):
        """
        Main loop: waits for jobs on 'read_ready', executes them, updates
        their status and moves them to 'read_terminated', dropping one job
        from the destination queue when it already holds
        'self.maxTerminatedJobs' entries.
        """
        self.logger.info("Starting retrieve executor...")
        self.setSourceQueueName("read_ready")
        self.setDestinationQueueName("read_terminated")
        while True:
            self.wait()
            try:
                srcQueueLen = self.srcQueue.len()
                destQueueLen = self.destQueue.len()
            except Exception:
                self.logger.exception("Cache error: failed to retrieve queue length.")
            else:
                if srcQueueLen > 0:
                    self.jobObj = self.srcQueue.getJob()
                    self.jobId = self.jobObj.jobId
                    # Work on a copy: cleanup() clears self.nodeList
                    self.nodeList = self.jobObj.nodeList.copy()
                    if self.execute():
                        self.update("OK")
                    else:
                        self.update("ERROR")
                    self.cleanup()
                    try:
                        if destQueueLen >= self.maxTerminatedJobs:
                            self.destQueue.extractJob()
                        self.destQueue.insertJob(self.jobObj)
                        self.srcQueue.extractJob()
                    except Exception:
                        self.logger.exception(f"Failed to move job {self.jobObj.jobId} from '{self.srcQueue.name()}' to '{self.destQueue.name()}'")
                    else:
                        self.logger.info(f"Job {self.jobObj.jobId} MOVED from '{self.srcQueue.name()}' to '{self.destQueue.name()}'")