import logging
import os
import time

from sqlalchemy import (Column,
                        ForeignKey,
                        Integer,
                        Sequence,
                        String,
                        Float,
                        DateTime,
                        Boolean,
                        ARRAY,
                        ForeignKeyConstraint)
from assets.commons import (setup_logger)
from assets.commons.db_utils import get_pg_engine
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import relationship
from sqlalchemy.exc import OperationalError

# Module logger, built with the project's setup_logger helper; used by init_db().
logger = setup_logger(name='DB_SETUP')
# Declarative base shared by every model in this module; init_db() creates the
# tables registered on Base.metadata.
Base = declarative_base()


class GridFiles(Base):
    """ORM model for the ``grid_files`` table.

    Each row is keyed by (zipped_grid_name, quantity, run_id) and records the
    name of a FITS file belonging to a grid. The composite foreign key requires
    a matching ``grid_parameters`` row with the same zipped_grid_name and run_id.

    NOTE(review): GridPars declares no relationship() to this model, so ORM
    cascades on GridPars do not delete grid_files rows; deleting a parent that
    still has grid_files rows will presumably fail on this foreign key.
    """
    __tablename__ = "grid_files"
    __table_args__ = (
        # Composite FK: (zipped_grid_name, run_id) -> grid_parameters.
        ForeignKeyConstraint(
            ('zipped_grid_name', 'run_id'),
            ['grid_parameters.zipped_grid_name', 'grid_parameters.run_id']
        ),
    )
    zipped_grid_name = Column(String(150), primary_key=True)
    # Presumably the physical quantity stored in the file (e.g. density or
    # temperature) — TODO confirm against the writer code.
    quantity = Column(String(30), primary_key=True)
    fits_grid_name = Column(String)
    created_on = Column(DateTime)
    run_id = Column(String, primary_key=True)


class GridPars(Base):
    """ORM model for the ``grid_parameters`` table.

    Rows are keyed by (zipped_grid_name, run_id). The four relationship()
    attributes are one-to-many collections of child rows (species/partners,
    model parameters, line parameters, star parameters). Each child has a
    composite foreign key (zipped_grid_name, run_id) that points back here.
    With ``cascade="all, delete-orphan"``, the ORM deletes those children when
    the parent is deleted through a session, or when a child is removed from
    its collection. ``grid_files`` also references this table but has no
    relationship here, so it is not covered by these cascades.
    """
    __tablename__ = "grid_parameters"
    zipped_grid_name = Column(String(150), primary_key=True)
    species_and_partners = relationship("SpeciesAndPartners", cascade="all, delete-orphan")
    # NOTE(review): despite its name, this attribute is a collection of ModelPars
    # rows, not a file name; renaming would break callers, so it is kept as-is.
    fits_cube_name = relationship("ModelPars", cascade="all, delete-orphan")
    line_pars = relationship("LinePars", cascade="all, delete-orphan")
    stars_pars = relationship("StarsPars", cascade="all, delete-orphan")
    # Physical-model parameters of the grid. Units and allowed values are not
    # stated in this module — presumably fixed by the grid-generation code (TODO confirm).
    grid_type = Column(String)
    coordinate_system = Column(String)
    central_density = Column(Float)
    density_powerlaw_index = Column(Float)
    density_at_reference = Column(Float)
    dust_temperature = Column(Float)
    dust_temperature_powerlaw_index = Column(Float)
    dust_temperature_at_reference = Column(Float)
    microturbulence = Column(Float)
    velocity_field = Column(String)
    velocity_gradient = Column(String)
    velocity_powerlaw_index = Column(Float)
    velocity_at_reference = Column(Float)
    distance_reference = Column(Float)
    maximum_radius = Column(Float)
    # Per-axis grid geometry (axes 1..3).
    # NOTE(review): grid_shape_* are stored as Float although a shape looks like
    # an integer cell count — confirm before relying on exact comparisons.
    grid_size_1 = Column(Float)
    grid_shape_1 = Column(Float)
    grid_refpix_1 = Column(Float)
    grid_size_2 = Column(Float)
    grid_shape_2 = Column(Float)
    grid_refpix_2 = Column(Float)
    grid_size_3 = Column(Float)
    grid_shape_3 = Column(Float)
    grid_refpix_3 = Column(Float)
    created_on = Column(DateTime)
    run_id = Column(String, primary_key=True)


class StarsPars(Base):
    """ORM model for the ``stars_parameters`` table.

    The primary key (zipped_grid_name, run_id) is the same column pair as the
    foreign key to ``grid_parameters``, so there is at most one row per grid
    and run. Several columns use ``ARRAY(Float)``, which needs a backend with
    array support (this module's engine comes from ``get_pg_engine``).
    """
    __tablename__ = "stars_parameters"
    __table_args__ = (
        # Composite FK: (zipped_grid_name, run_id) -> grid_parameters.
        ForeignKeyConstraint(
            ('zipped_grid_name', 'run_id'),
            ['grid_parameters.zipped_grid_name', 'grid_parameters.run_id']
        ),
    )
    zipped_grid_name = Column(String(150), primary_key=True)
    nstars = Column(Integer)
    # Presumably per-star radii, masses, positions and fluxes; units and array
    # layout are not stated here — TODO confirm.
    rstars = Column(ARRAY(Float))
    mstars = Column(ARRAY(Float))
    star_positions = Column(ARRAY(Float))
    star_fluxes = Column(ARRAY(Float))
    nlambdas = Column(Integer)
    # Presumably how the nlambdas wavelengths are spaced — TODO confirm allowed values.
    spacing = Column(String)
    # Wavelength-range limits; microns judging by the name — confirm.
    lambdas_micron_limits = Column(ARRAY(Float))
    created_on = Column(DateTime)
    run_id = Column(String, primary_key=True)


class LinePars(Base):
    """ORM model for the ``lines_parameters`` table.

    The primary key (zipped_grid_name, run_id) is also the foreign key to
    ``grid_parameters``, so there is at most one row per grid and run.
    """
    __tablename__ = "lines_parameters"
    __table_args__ = (
        # Composite FK: (zipped_grid_name, run_id) -> grid_parameters.
        ForeignKeyConstraint(
            ('zipped_grid_name', 'run_id'),
            ['grid_parameters.zipped_grid_name', 'grid_parameters.run_id']
        ),
    )
    zipped_grid_name = Column(String(150), primary_key=True)
    # Line-computation mode label; allowed values are not defined here — TODO confirm.
    lines_mode = Column(String(20))
    created_on = Column(DateTime)
    run_id = Column(String, primary_key=True)


class SpeciesAndPartners(Base):
    """ORM model for the ``species_and_partners`` table.

    One row per (species_to_include, collision_partner) pair for a given grid
    and run; the primary key is (zipped_grid_name, species_to_include,
    collision_partner, run_id). (zipped_grid_name, run_id) is a foreign key to
    ``grid_parameters``.
    """
    __tablename__ = "species_and_partners"
    __table_args__ = (
        # Composite FK: (zipped_grid_name, run_id) -> grid_parameters.
        ForeignKeyConstraint(
            ('zipped_grid_name', 'run_id'),
            ['grid_parameters.zipped_grid_name', 'grid_parameters.run_id']
        ),
    )
    zipped_grid_name = Column(String(150), primary_key=True)
    species_to_include = Column(String(100), primary_key=True)
    molecular_abundance = Column(Float)
    # Presumably an abundance jump applied beyond a threshold value — TODO confirm
    # which quantity the threshold refers to.
    threshold = Column(Float)
    abundance_jump = Column(Float)
    collision_partner = Column(String(100), primary_key=True)
    molecular_abundance_collision_partner = Column(Float)
    created_on = Column(DateTime)
    run_id = Column(String, primary_key=True)


class ModelPars(Base):
    """ORM model for the ``model_parameters`` table.

    Rows are keyed by (fits_cube_name, run_id). ``zipped_grid_name`` is
    required (NOT NULL) and forms, together with run_id, the foreign key to
    ``grid_parameters``. Moment-zero maps reference this table through
    (fits_cube_name, run_id).
    """
    __tablename__ = "model_parameters"
    __table_args__ = (
        # Composite FK: (zipped_grid_name, run_id) -> grid_parameters.
        ForeignKeyConstraint(
            ('zipped_grid_name', 'run_id'),
            ['grid_parameters.zipped_grid_name', 'grid_parameters.run_id']
        ),
    )
    zipped_grid_name = Column(String(150), nullable=False)
    fits_cube_name = Column(String(150), primary_key=True)
    # NOTE(review): despite its name, this is a one-to-many collection of
    # MomentZeroMaps rows (cascade "all, delete-orphan"), not a map name.
    mom_zero_name = relationship("MomentZeroMaps", cascade="all, delete-orphan")
    # Settings of the run that produced the cube; presumably passed through to
    # the radiative-transfer code — TODO confirm meanings/units.
    nphotons = Column(Float)
    scattering_mode_max = Column(Integer)
    iranfreqmode = Column(Integer)
    tgas_eq_tdust = Column(Integer)
    inclination = Column(Float)
    position_angle = Column(Float)
    imolspec = Column(Integer)
    iline = Column(Integer)
    width_kms = Column(Float)
    nchannels = Column(Integer)
    npix = Column(Integer)
    created_on = Column(DateTime)
    run_id = Column(String, primary_key=True)


class MomentZeroMaps(Base):
    """ORM model for the ``moment_zero_maps`` table.

    Rows are keyed by (mom_zero_name, run_id). ``fits_cube_name`` is required
    (NOT NULL) and forms, together with run_id, the foreign key to
    ``model_parameters``.
    """
    __tablename__ = "moment_zero_maps"
    __table_args__ = (
        # Composite FK: (fits_cube_name, run_id) -> model_parameters.
        ForeignKeyConstraint(
            ('fits_cube_name', 'run_id'),
            ['model_parameters.fits_cube_name', 'model_parameters.run_id']
        ),
    )
    mom_zero_name = Column(String(150), primary_key=True)
    fits_cube_name = Column(String(150), nullable=False)
    # Integration bounds of the moment; unit (velocity vs. channel) not stated — TODO confirm.
    integration_limit_low = Column(Float)
    integration_limit_high = Column(Float)
    # Scalar summary of the map, produced by the function named in aggregation_function.
    aggregated_moment_zero = Column(Float)
    aggregation_function = Column(String(20))
    created_on = Column(DateTime)
    run_id = Column(String, primary_key=True)


class RatioMaps(Base):
    """ORM model for the ``ratio_maps`` table.

    Each row (keyed by ratio_map_name and run_id) references two moment-zero
    maps of the same run through two composite foreign keys. These are exposed
    as the many-to-one relationships ``mom_zero_map_1`` and ``mom_zero_map_2``.
    """
    __tablename__ = "ratio_maps"
    __table_args__ = (
        # (mom_zero_name_1, run_id) -> moment_zero_maps
        ForeignKeyConstraint(
            ('mom_zero_name_1', 'run_id'),
            ['moment_zero_maps.mom_zero_name', 'moment_zero_maps.run_id'],
        ),
        # (mom_zero_name_2, run_id) -> moment_zero_maps; run_id is shared with
        # the constraint above, so both maps must belong to the same run.
        ForeignKeyConstraint(
            ('mom_zero_name_2', 'run_id'),
            ['moment_zero_maps.mom_zero_name', 'moment_zero_maps.run_id'],
        ),
    )
    ratio_map_name = Column(String(150), primary_key=True)
    mom_zero_name_1 = Column(String(150), nullable=False)
    mom_zero_name_2 = Column(String(150), nullable=False)
    aggregated_ratio = Column(Float)
    aggregation_function = Column(String(20))
    created_on = Column(DateTime)
    run_id = Column(String, primary_key=True)
    # Both relationships copy moment_zero_maps.run_id into ratio_maps.run_id.
    # The overlap is intentional (numerator and denominator come from the same
    # run). Declaring it via ``overlaps`` tells SQLAlchemy so, instead of
    # emitting the "relationship ... will copy column ... which conflicts with
    # relationship(s)" warning at mapper configuration.
    mom_zero_map_1 = relationship(
        "MomentZeroMaps",
        foreign_keys=[mom_zero_name_1, run_id],
        overlaps="mom_zero_map_2",
    )
    mom_zero_map_2 = relationship(
        "MomentZeroMaps",
        foreign_keys=[mom_zero_name_2, run_id],
        overlaps="mom_zero_map_1",
    )


class TmpExecutionQueue(Base):
    """ORM model for the ``tmp_execution_queue`` table.

    Each row is one combination of (run_id, dust_temperature, density, line,
    density_keyword, dust_temperature_keyword); all six columns together form
    the primary key. ``row_id`` is not part of the key; its default value is
    drawn from the ``row_id_seq`` sequence. ``fits_cube_name`` and ``done``
    presumably record the output cube and completion flag of each queued job —
    TODO confirm against the worker code.

    NOTE(review): two Float columns are part of the primary key, so lookups
    depend on exact floating-point equality.
    """
    __tablename__ = "tmp_execution_queue"
    row_id = Column(Integer, Sequence('row_id_seq'))
    run_id = Column(String, primary_key=True)
    dust_temperature = Column(Float, primary_key=True)
    density = Column(Float, primary_key=True)
    line = Column(Integer, primary_key=True)
    density_keyword = Column(String, primary_key=True)
    dust_temperature_keyword = Column(String, primary_key=True)
    fits_cube_name = Column(String)
    done = Column(Boolean)


def init_db(max_retries=1, retry_delay=3):
    """Create the tables registered on ``Base.metadata``.

    ``create_all`` is used with its default ``checkfirst`` behavior, so tables
    that already exist are left untouched. A connection failure
    (``OperationalError``) is retried up to ``max_retries`` times, sleeping
    ``retry_delay`` seconds before each retry. If the last attempt also fails,
    the error is re-raised. The engine is disposed in every case, including
    failure, so pooled connections are not leaked.

    :param max_retries: number of retries after the first failed attempt.
        Negative values are treated as 0. The default of 1 matches the
        original behavior.
    :param retry_delay: seconds to sleep before each retry (default 3).
    :raises OperationalError: if every attempt fails to connect.
    """
    retries = max(int(max_retries), 0)
    engine = get_pg_engine(logger=logger)
    try:
        for attempt in range(retries + 1):
            try:
                Base.metadata.create_all(bind=engine)
                break
            except OperationalError:
                if attempt >= retries:
                    logger.error('Connection failed after %d attempt(s). Giving up.', attempt + 1)
                    raise
                logger.error('Connection failed. Sleeping and retrying')
                time.sleep(retry_delay)
        logger.info('Connection successful! DB initialized as needed.')
    finally:
        # Release the connection pool even when initialization fails.
        engine.dispose()


if __name__ == '__main__':
    # Script entry point: build the database schema through init_db().
    init_db()
