import base64
import logging
import os
from contextlib import closing
from typing import Union, List
import sqlalchemy as sqla
import yaml
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session
from assets.commons import validate_parameter


def get_credentials(logger: logging.Logger,
                    credentials_filename: Union[None, str] = None,
                    set_envs: bool = False) -> Union[None, dict]:
    """
    Retrieve the credentials stored in a YAML credentials file, if available.

    The ``DB_USER`` and ``DB_PASS`` entries are stored base64-encoded in the file and are
    returned decoded.
    :param logger: the logger to use.
    :param credentials_filename: the path of the credentials file to be used. If None, uses
        ``credentials/credentials.yml`` (relative to the current working directory).
    :param set_envs: whether to also export every credential as an environment variable
        (values are converted with ``str``).
    :return: a dictionary with the file contents (``DB_USER``/``DB_PASS`` decoded), or None if
        the file does not exist or is empty.
    :raises KeyError: if the file lacks a ``DB_USER`` or ``DB_PASS`` entry.
    """
    _credentials_filename = credentials_filename if credentials_filename is not None \
        else os.path.join('credentials', 'credentials.yml')
    # Only the file access itself can raise FileNotFoundError; keep the try body minimal.
    try:
        with open(_credentials_filename, encoding='utf-8') as credentials_file:
            credentials = yaml.load(credentials_file, Loader=yaml.FullLoader)
    except FileNotFoundError:
        logger.info('Credentials not found!')
        return None

    # An empty YAML document loads as None; treat it as "no credentials available"
    # instead of failing with an opaque TypeError on the lookups below.
    if credentials is None:
        logger.info('Credentials file %s is empty!', _credentials_filename)
        return None

    credentials['DB_USER'] = base64.b64decode(credentials['DB_USER'].encode()).decode()
    credentials['DB_PASS'] = base64.b64decode(credentials['DB_PASS'].encode()).decode()

    if set_envs:
        for key, value in credentials.items():
            os.environ[key] = str(value)

    return credentials


def upsert(table_object: sqla.orm.decl_api.DeclarativeMeta,
           row_dict: dict,
           conflict_keys: List[sqla.Column],
           engine: sqla.engine.Engine):
    """
    Insert a row into a PostgreSQL table, or update the existing row when the insert conflicts
    on the indicated columns (``INSERT ... ON CONFLICT (...) DO UPDATE SET ...``).
    :param table_object: the mapped table into which the row must be inserted
    :param row_dict: the dictionary representing the row to insert; the same mapping is used
        as the SET clause when a conflict occurs
    :param conflict_keys: the conflict columns list, representing the primary key of the table;
        PostgreSQL requires a unique index or constraint covering exactly these columns
    :param engine: the SQLAlchemy engine to use
    """
    statement = insert(table_object).values(row_dict)
    statement = statement.on_conflict_do_update(index_elements=conflict_keys, set_=row_dict)
    # session.begin() commits when the block exits normally and rolls back if execute()
    # raises; the outer context manager closes the session in both cases.
    with Session(engine) as session, session.begin():
        session.execute(statement)


def get_pg_engine(logger: logging.Logger, engine_kwargs: Union[dict, None] = None) -> sqla.engine.Engine:
    """
    Return the SQLAlchemy engine for the PostgreSQL database described in
    ``credentials/db_credentials.yml``.
    :param logger: the logger to use
    :param engine_kwargs: extra keyword arguments forwarded to ``sqlalchemy.create_engine``;
        when None, an empty dict is used (assumes validate_parameter substitutes the default)
    :return: the SQLAlchemy engine
    :raises RuntimeError: if the database credentials file could not be loaded
    """
    _kwargs = validate_parameter(engine_kwargs, default={})
    db_credentials_filename = os.path.join('credentials', 'db_credentials.yml')
    credentials = get_credentials(logger=logger, credentials_filename=db_credentials_filename)
    if credentials is None:
        raise RuntimeError(f'Database credentials could not be loaded from {db_credentials_filename}')

    port = credentials.get('DB_PORT')
    # URL.create escapes every component, so user names or passwords containing characters
    # such as '@', ':', '/' or '#' do not corrupt the connection URL as an f-string would.
    url = sqla.engine.URL.create(
        drivername='postgresql',
        username=credentials['DB_USER'],
        password=credentials['DB_PASS'],
        host=credentials['DB_HOST'],
        port=int(port) if port is not None else None,
        database=credentials['DB_NAME'],
    )
    return sqla.create_engine(url, **_kwargs)
