import glob
import os
import uuid
import sqlalchemy
import argparse
import sys
from multiprocessing import Pool
from stg.stg_build_db_structure import init_db, TmpExecutionQueue
from itertools import product, chain
from typing import Union, Tuple, Iterator
from assets.commons import (cleanup_directory,
                            setup_logger,
                            validate_parameter)
from assets.commons.parsing import parse_input_main
from assets.commons.db_utils import upsert, get_pg_engine
from stg.stg_radmc_input_generator import main as stg_main
from mdl.mdl_execute_radmc_command import main as execute_radmc_script
from prs.prs_compute_integrated_fluxes_and_ratios import main as prs_main
from prs.prs_inspect_results import main as prs_inspection_main


def compute_full_grid(tdust, nh2, line, density_keyword, dust_temperature_keyword) -> Tuple[float, float, int, str]:
    """Generate RADMC inputs for one grid point and run RADMC on them.

    A fresh scratch directory ``mdl/scratches/<uuid4>`` is used for the inputs.
    The same override dictionary is handed to both the input generator
    (``stg_main``) and the RADMC runner (``execute_radmc_script``).

    NOTE(review): ``run_id`` is not a parameter; it is read from the module
    global bound in the ``__main__`` block. Pool workers therefore presumably
    rely on inheriting it (fork start method) — confirm on spawn platforms.

    Args:
        tdust: dust temperature value for this grid point.
        nh2: density value for this grid point.
        line: line index, passed as ``model.radmc_observation.iline``.
        density_keyword: grid config key under which ``nh2`` is set.
        dust_temperature_keyword: grid config key under which ``tdust`` is set.

    Returns:
        ``(tdust, nh2, line, cube_fits_name)``, where the last element is
        whatever ``execute_radmc_script`` returned.
    """
    scratch_dir = os.path.join('mdl', 'scratches', str(uuid.uuid4()))
    grid_values = {dust_temperature_keyword: tdust, density_keyword: nh2}
    overrides = {
        'grid_lines': {'grid': grid_values},
        'model': {'radmc_observation': {'iline': line}},
    }

    tarname = stg_main(override_config=overrides, path_radmc_files=scratch_dir, run_id=run_id)
    fits_name = execute_radmc_script(grid_zipfile=tarname,
                                     override_config=overrides,
                                     radmc_input_path=scratch_dir,
                                     run_id=run_id)
    return tdust, nh2, line, fits_name


def initialize_queue(engine: sqlalchemy.engine.Engine,
                     run_id: str,
                     run_arguments: Iterator):
    """Populate ``tmp_execution_queue`` with one pending row per grid point.

    Nothing is inserted when the queue already holds at least one row for
    ``run_id``. Re-using an existing run id therefore keeps the existing rows
    and their ``done`` flags instead of re-enqueuing.

    Args:
        engine: SQLAlchemy engine connected to the queue database.
        run_id: identifier of the run the rows belong to.
        run_arguments: iterable of 5-tuples
            ``(dust_temperature, density, line, density_keyword,
            dust_temperature_keyword)``, as produced by
            ``get_parallel_args_and_nprocesses``.
    """
    count_query = sqlalchemy.text(
        "SELECT count(*) FROM tmp_execution_queue WHERE run_id = :run_id")
    # Bound parameter rather than f-string interpolation: run_id can come from
    # the command line, and a quote in it would break or alter the statement.
    already_queued = engine.execute(count_query, {'run_id': run_id}).first()[0] != 0
    if already_queued:
        return

    for tdust, density, line, density_keyword, dust_temperature_keyword in run_arguments:
        raw_insert_entry = {'run_id': run_id,
                            'dust_temperature': tdust,
                            'density': density,
                            'line': line,
                            'density_keyword': density_keyword,
                            'dust_temperature_keyword': dust_temperature_keyword,
                            'done': False}
        upsert(
            table_object=TmpExecutionQueue,
            row_dict=raw_insert_entry,
            # Columns identifying a grid point; used as the upsert conflict target.
            conflict_keys=[
                TmpExecutionQueue.run_id,
                TmpExecutionQueue.dust_temperature,
                TmpExecutionQueue.density,
                TmpExecutionQueue.line,
                TmpExecutionQueue.density_keyword,
                TmpExecutionQueue.dust_temperature_keyword
            ],
            engine=engine
        )


def get_run_pars(engine: sqlalchemy.engine.Engine,
                 run_id: str):
    """Claim one pending grid point of ``run_id`` and return its queue row.

    A single ``UPDATE ... RETURNING`` statement flips ``done`` to true on one
    row that is not yet done. That row is selected with ``FOR UPDATE`` plus a
    transaction-scoped advisory lock, presumably so that concurrent workers do
    not claim the same row.

    Args:
        engine: SQLAlchemy engine connected to the queue database.
        run_id: identifier of the run to take work from.

    Returns:
        The claimed row (all columns, via ``RETURNING *``), or None when no
        pending row could be claimed.
    """
    sql_query = sqlalchemy.text("""UPDATE tmp_execution_queue
                SET done = true
                WHERE row_id = (SELECT row_id
                                   FROM tmp_execution_queue
                                   WHERE (run_id = :run_id)
                                        AND (done is false)
                                        AND pg_try_advisory_xact_lock(row_id)
                                   LIMIT 1 FOR UPDATE)
                RETURNING *""")
    # run_id is bound, not interpolated, so quotes in it cannot break the SQL.
    return engine.execution_options(autocommit=True).execute(sql_query, {'run_id': run_id}).first()


def verify_run(engine: sqlalchemy.engine.Engine,
               run_id: str) -> bool:
    """Re-open orphaned claims of ``run_id`` and report whether every model finished.

    Rows marked ``done`` but without a ``fits_cube_name`` are flagged as not
    done again, so a later pass picks them up. Such rows were claimed by
    ``get_run_pars`` but never got a cube recorded by ``insert_fits_name``.

    Args:
        engine: SQLAlchemy engine used for the reset statement.
        run_id: identifier of the run to verify.

    Returns:
        True when ``compute_remaining_models`` reports zero outstanding models.
    """
    sql_query = sqlalchemy.text("""UPDATE tmp_execution_queue
                SET done = false
                WHERE row_id in (SELECT row_id
                                   FROM tmp_execution_queue
                                   WHERE run_id = :run_id
                                        AND (done is true)
                                        AND (fits_cube_name is null))""")
    engine.execution_options(autocommit=True).execute(sql_query, {'run_id': run_id})
    # compute_remaining_models opens (and disposes) its own engine.
    return compute_remaining_models(run_id=run_id) == 0


def insert_fits_name(engine: sqlalchemy.engine.Engine,
                     row_id: int,
                     fits_cube_name: str):
    """Record the produced FITS cube name on a queue row.

    Values are passed as bound parameters, so a None ``fits_cube_name`` is
    stored as SQL NULL. Interpolating it would store the literal string
    ``'None'``, which ``verify_run``'s ``fits_cube_name is null`` check would
    never reset. Quotes in the file name can no longer break the statement.

    Args:
        engine: SQLAlchemy engine connected to the queue database.
        row_id: primary key of the row claimed by ``get_run_pars``.
        fits_cube_name: name returned by the RADMC execution step.
    """
    sql_query = sqlalchemy.text("""UPDATE tmp_execution_queue
                SET fits_cube_name = :fits_cube_name
                WHERE row_id = :row_id""")
    engine.execution_options(autocommit=True).execute(
        sql_query, {'fits_cube_name': fits_cube_name, 'row_id': row_id})


def compute_grid_elements(run_id: str):
    """Call ``init_db`` and enqueue every grid point of the configured grid under ``run_id``.

    The engine used for enqueuing is created here and disposed afterwards.
    """
    init_db()
    grid_points = get_parallel_args_and_nprocesses()[0]
    queue_engine = get_pg_engine(logger=logger)
    initialize_queue(engine=queue_engine, run_id=run_id, run_arguments=grid_points)
    queue_engine.dispose()


def get_parallel_args_and_nprocesses() -> Tuple[Iterator, int]:
    """Build the grid of model arguments and the worker count from the parsed input.

    Returns:
        ``(parallel_args, n_processes)``. ``parallel_args`` is an
        ``itertools.product`` iterator of
        ``(dust_temperature, density, line, density_keyword,
        dust_temperature_keyword)`` tuples. Every distinct line appearing in
        any line pair is included once.
    """
    parsed = parse_input_main()
    tdust_model_type, model_type = parsed[0], parsed[1]
    dust_temperatures, densities, line_pairs, n_processes = parsed[2], parsed[3], parsed[4], parsed[5]

    # The grid config key depends on the density / temperature profile type.
    if model_type == 'homogeneous':
        density_keyword = 'central_density'
    else:
        density_keyword = 'density_at_reference'
    if tdust_model_type == 'isothermal':
        dust_temperature_keyword = 'dust_temperature'
    else:
        dust_temperature_keyword = 'dust_temperature_at_reference'

    unique_lines = set(chain.from_iterable(line_pairs))
    return (product(dust_temperatures,
                    densities,
                    unique_lines,
                    [density_keyword],
                    [dust_temperature_keyword]),
            n_processes)


def compute_model(run_id: str):
    """Claim one queued grid point of ``run_id``, compute it, and store the cube name.

    When no row can be claimed, only a completion message is logged. A separate
    short-lived engine is used before and after the long-running computation.
    """
    claim_engine = get_pg_engine(logger=logger)
    claimed = get_run_pars(engine=claim_engine, run_id=run_id)
    claim_engine.dispose()

    if claimed is None:
        logger.info('All models were completed.')
        return

    # Row layout follows the RETURNING * of get_run_pars:
    # [0] row_id, [2] dust temperature, [3] density, [4] line,
    # [5] density keyword, [6] dust temperature keyword.
    fits_cube_name = compute_full_grid(tdust=claimed[2],
                                       nh2=claimed[3],
                                       line=claimed[4],
                                       density_keyword=claimed[5],
                                       dust_temperature_keyword=claimed[6])[3]

    record_engine = get_pg_engine(logger=logger)
    insert_fits_name(engine=record_engine, row_id=claimed[0], fits_cube_name=fits_cube_name)
    record_engine.dispose()


def initialize_run():
    """Resolve the run id, enqueue the grid for it, and publish the id.

    The id is taken from ``--run_id`` when given; otherwise a new UUID4 is
    generated. It is written to stdout (no trailing newline) and to
    ``run_id.txt`` (with a newline).

    Returns:
        The run id string.
    """
    run_id = args.run_id
    if run_id is None:
        logger.info('Generating new run_id')
        run_id = str(uuid.uuid4())

    compute_grid_elements(run_id=run_id)
    sys.stdout.write(run_id)
    with open('run_id.txt', 'w') as handle:
        handle.write(f'{run_id}\n')
    return run_id


def compute_remaining_models(run_id: Union[None, str] = None) -> int:
    """Count grid points of a run that still lack a recorded result.

    A row counts as remaining when it is not done, or when it is done but has
    no ``fits_cube_name``.

    Args:
        run_id: run to inspect; when None, falls back (via validate_parameter)
            to the ``run_id`` environment variable.

    Returns:
        The number of remaining models. It is also written to stdout.
    """
    _run_id = validate_parameter(run_id, default=os.getenv('run_id'))
    logger.info(_run_id)
    # Use the resolved _run_id (not the raw argument) so the env-var fallback
    # actually applies; bound as a parameter instead of interpolated.
    sql_query = sqlalchemy.text("""SELECT count(*)
                                    FROM tmp_execution_queue
                                    WHERE (run_id = :run_id)
                                        AND ((done is false)
                                        OR ((done is true) AND (fits_cube_name is null)))""")
    engine = get_pg_engine(logger=logger)
    try:
        n_models = engine.execution_options(autocommit=True).execute(
            sql_query, {'run_id': _run_id}).first()[0]
    finally:
        engine.dispose()
    sys.stdout.write(str(n_models))
    return n_models


def get_results(engine: sqlalchemy.engine.Engine,
                run_id: str):
    """Return every queue row of ``run_id`` as ``(dust_temperature, density, line, fits_cube_name)``.

    Args:
        engine: SQLAlchemy engine connected to the queue database.
        run_id: identifier of the run to read; bound as a query parameter.

    Returns:
        All matching rows, as returned by the result's ``.all()``.
    """
    sql_query = sqlalchemy.text("""SELECT dust_temperature
                                           , density
                                           , line
                                           , fits_cube_name
                                    FROM tmp_execution_queue
                                    WHERE run_id = :run_id""")
    return engine.execution_options(autocommit=True).execute(sql_query, {'run_id': run_id}).all()


def cleanup_tmp_table(run_id: str,
                      engine: sqlalchemy.engine.Engine):
    """Delete every ``tmp_execution_queue`` row belonging to ``run_id``.

    Args:
        run_id: identifier of the run whose queue rows are removed; bound as a
            query parameter so quotes in it cannot widen the DELETE.
        engine: SQLAlchemy engine connected to the queue database.

    Returns:
        The result object of the DELETE execution.
    """
    sql_query = sqlalchemy.text("""DELETE
                                    FROM tmp_execution_queue
                                    WHERE run_id = :run_id""")
    return engine.execution_options(autocommit=True).execute(sql_query, {'run_id': run_id})


def main_presentation_step(run_id: str,
                           cleanup_scratches: bool = True,
                           results_dict: Union[list, None] = None) -> bool:
    """Summarise a finished run: compute line-pair products, inspect results, verify the queue.

    Output folders ``prs/output/run_type/<run_type>/{data,figures,trained_model}``
    are created first. Then ``prs_main`` is called once per line pair and
    (dust temperature, density) combination, with the two matching cubes.

    Args:
        run_id: identifier of the run being summarised.
        cleanup_scratches: when True, every directory under ``mdl/scratches``
            is removed with ``cleanup_directory``, not only this run's.
        results_dict: optional iterable of ``(tdust, nh2, line, cube_fits_name)``
            tuples (despite the name, the list returned by the local pool path
            of ``process_models``). When None, the same columns are read from
            ``tmp_execution_queue`` with ``get_results``.

    Returns:
        The outcome of ``verify_run``. When True, the run's queue rows are
        deleted as well.

    Raises:
        KeyError: if a required (density, temperature, line) cube is missing
            from the results.
    """
    (tdust_model_type, model_type, dust_temperatures, densities,
     line_pairs, _n_processes, run_type) = parse_input_main()

    for folder in ('data', 'figures', 'trained_model'):
        os.makedirs(os.path.join('prs', 'output', 'run_type', run_type, folder),
                    exist_ok=True)

    def _cube_key(nh2, tdust, line) -> str:
        # One definition of the lookup key so building and reading the map agree.
        return f'{str(nh2)}_{str(tdust)}_{line}'

    engine = get_pg_engine(logger=logger)
    try:
        # Query the database only when the caller did not supply results;
        # a default argument would run the query even when it is unused.
        # NOTE(review): assumes validate_parameter only substitutes on None.
        if results_dict is None:
            results_dict = get_results(engine=engine, run_id=run_id)
        results_map = {_cube_key(nh2, tdust, line): cube_fits_name
                       for (tdust, nh2, line, cube_fits_name) in results_dict}

        for line_pair in line_pairs:
            for tdust, nh2 in product(dust_temperatures, densities):
                prs_main(cube_fits_list=[results_map[_cube_key(nh2, tdust, line_pair[0])],
                                         results_map[_cube_key(nh2, tdust, line_pair[1])]],
                         run_id=run_id,
                         engine=engine)

        if cleanup_scratches is True:
            for scratch_dir in glob.glob(os.path.join('mdl', 'scratches', '*')):
                cleanup_directory(directory=scratch_dir, logger=logger)

        prs_inspection_main(run_id=run_id,
                            is_isothermal=tdust_model_type == 'isothermal',
                            is_homogeneous=model_type == 'homogeneous',
                            engine=engine,
                            run_type=run_type)
        run_success = verify_run(run_id=run_id, engine=engine)
        if run_success is True:
            cleanup_tmp_table(run_id=run_id, engine=engine)
    finally:
        # Release pooled connections even if a presentation step raises.
        engine.dispose()
    return run_success


def process_models(distributed: bool = False) -> Tuple[Union[None, list], int]:
    """Compute grid models, either from the shared DB queue or locally with a pool.

    NOTE(review): ``run_id`` is read from the module global bound in the
    ``__main__`` block, not passed in.

    Args:
        distributed: when exactly True, claim and compute a single queued model
            via ``compute_model`` and then count what remains. Otherwise compute
            the whole grid with a local ``multiprocessing.Pool``.

    Returns:
        ``(results, remaining_models)``. In distributed mode ``results`` is None.
        Otherwise it is the list of ``(tdust, nh2, line, cube_fits_name)``
        tuples produced by ``compute_full_grid``, and ``remaining_models`` is 0.
    """
    if distributed is True:
        # Only one queued model is computed per invocation in this mode.
        compute_model(run_id=run_id)
        return None, compute_remaining_models(run_id)

    parallel_args, n_processes = get_parallel_args_and_nprocesses()
    with Pool(n_processes) as pool:
        results = pool.starmap(compute_full_grid, parallel_args)
    return results, 0


logger = setup_logger(name='MAIN')
parser = argparse.ArgumentParser()
# Every option is an optional string; a missing option is None and is
# resolved later through validate_parameter defaults.
# NOTE(review): parsing runs at import time, so importing this module parses
# the importer's sys.argv as well.
for _option in ('--run_id', '--cleanup_scratches', '--distributed'):
    parser.add_argument(_option)
del _option
args = parser.parse_args()

if __name__ == '__main__':
    # Bound at module level on purpose: compute_full_grid and process_models
    # read ``run_id`` as a global rather than receiving it as an argument.
    run_id = initialize_run()
    # NOTE(review): assert is stripped under ``python -O``; initialize_run
    # always returns a string, so this is only a sanity check.
    assert run_id is not None
    # Only the string 'true' (any case) enables distributed mode.
    _distributed = validate_parameter(args.distributed, default='false').lower() == 'true'

    # Distributed: compute one queued model, then count the remaining ones.
    # Local: compute the whole grid with a process pool (remaining is 0).
    results, remaining_models = process_models(distributed=_distributed)
    # With models still outstanding (distributed mode), nothing more happens here.
    if remaining_models == 0:
        logger.info('All grid points processed. Summarizing results.')
        _cleanup = validate_parameter(args.cleanup_scratches,
                                      default='true').lower() == 'true'
        _run_success = main_presentation_step(run_id=run_id,
                                              cleanup_scratches=_cleanup,
                                              results_dict=results)
        # In local mode the pool calls compute_full_grid directly and never marks
        # queue rows done, so verify_run cannot report success there; the queue
        # check is therefore only decisive in distributed mode.
        if (_run_success is True) or (_distributed is False):
            logger.info('The run completed successfully!')
        else:
            logger.error('The run was incomplete. I have reset the done flag in the database for incomplete models')
