import base64
import logging
import os
from contextlib import closing
from typing import Union, List

import sqlalchemy as sqla
import yaml
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session

from assets.commons import validate_parameter


def get_credentials(logger: logging.Logger, credentials_filename: Union[None, str] = None,
                    set_envs: bool = False) -> Union[None, dict]:
    """
    Retrieves credentials, if available, from a YAML credentials file.

    The ``DB_USER`` and ``DB_PASS`` entries are stored base64-encoded in the file and are
    returned decoded.

    :param logger: the logger to use.
    :param credentials_filename: the name of the credentials file to be used. If None, uses
        ``credentials/credentials.yml``.
    :param set_envs: whether to also export every credential as an environment variable
        (values are converted with ``str``).
    :return: a dictionary with the credentials read from the file (``DB_USER`` and
        ``DB_PASS`` decoded), or None if the file does not exist or is empty.
    :raises KeyError: if the file does not contain ``DB_USER`` or ``DB_PASS``.
    """
    _credentials_filename = credentials_filename if credentials_filename is not None \
        else os.path.join('credentials', 'credentials.yml')
    try:
        with open(_credentials_filename, encoding='utf-8') as credentials_file:
            # NOTE(review): FullLoader can construct some Python objects; yaml.safe_load is
            # enough for a plain key/value file - consider switching if this file can be
            # written by anyone other than the deployer.
            credentials = yaml.load(credentials_file, Loader=yaml.FullLoader)
    except FileNotFoundError:
        logger.info('Credentials not found!')
        return None

    # An empty YAML document loads as None; treat it like "no credentials" instead of
    # failing with an opaque TypeError on the subscripts below.
    if credentials is None:
        logger.warning('Credentials file %s is empty!', _credentials_filename)
        return None

    # User and password are kept base64-encoded on disk.
    for encoded_key in ('DB_USER', 'DB_PASS'):
        credentials[encoded_key] = base64.b64decode(credentials[encoded_key].encode()).decode()

    if set_envs:
        for key, value in credentials.items():
            os.environ[key] = str(value)
    return credentials


def upsert(table_object: sqla.orm.decl_api.DeclarativeMeta, row_dict: dict,
           conflict_keys: List[sqla.Column], engine: sqla.engine.Engine) -> None:
    """
    Upsert the row into the specified table, according to the indicated conflict columns.

    Uses PostgreSQL ``INSERT ... ON CONFLICT (...) DO UPDATE``: if a row with the same
    values in ``conflict_keys`` exists, all columns in ``row_dict`` are overwritten.

    :param table_object: the table into which the row must be inserted
    :param row_dict: the dictionary representing the row to insert
    :param conflict_keys: the conflict columns list, representing the primary key of the table
    :param engine: the SQLAlchemy engine to use
    """
    statement = insert(table_object).values(row_dict)
    statement = statement.on_conflict_do_update(index_elements=conflict_keys, set_=row_dict)
    # closing() guarantees the session is closed (and any open transaction discarded)
    # even if execute/commit raises.
    with closing(Session(engine)) as session:
        session.execute(statement)
        session.commit()


def get_pg_engine(logger: logging.Logger,
                  engine_kwargs: Union[dict, None] = None) -> sqla.engine.Engine:
    """
    Return the SQLAlchemy engine, given the credentials in the external file
    ``credentials/db_credentials.yml``.

    :param logger: the logger to use
    :param engine_kwargs: extra keyword arguments forwarded to ``sqlalchemy.create_engine``;
        None means no extra arguments.
    :return: the SQLAlchemy engine
    :raises FileNotFoundError: if the database credentials file is missing or empty.
    """
    _kwargs = validate_parameter(engine_kwargs, default={})
    credentials_path = os.path.join('credentials', 'db_credentials.yml')
    credentials = get_credentials(logger=logger, credentials_filename=credentials_path)
    if credentials is None:
        raise FileNotFoundError(f'Database credentials not available in {credentials_path}')
    # URL.create escapes special characters (e.g. '@', ':', '/') in the user name and
    # password, which a hand-built URL string does not; its repr also masks the password.
    url = sqla.engine.URL.create(
        drivername='postgresql',
        username=credentials['DB_USER'],
        password=credentials['DB_PASS'],
        host=str(credentials['DB_HOST']),
        port=credentials['DB_PORT'],
        database=str(credentials['DB_NAME']),
    )
    return sqla.create_engine(url, **_kwargs)