import os
import numpy as np
import argparse
from assets.commons import (load_config_file,
                            validate_parameter,
                            setup_logger)
from assets.commons.training_utils import train_models_wrapper

if __name__ == '__main__':
    external_input = load_config_file(config_file_path='config/ml_modelling.yml')
    logger = setup_logger(name='PRS - ML training')

    parser = argparse.ArgumentParser()
    parser.add_argument('--line', default=None)
    parser.add_argument('--distributed', default='true')
    args = parser.parse_args()
    try:
        limit_rows = external_input['limit_rows']
    except KeyError:
        limit_rows = None
    try:
        use_model_for_training = external_input['use_model_for_training']
    except KeyError:
        use_model_for_training = None
    _distributed = validate_parameter(args.distributed, default='false').lower() == 'true'

    _use_model_for_training = validate_parameter(
        use_model_for_training,
        default='constant_abundance_p15_q05'
    )
    _model_root_folder = os.path.join('prs', 'output', 'run_type', _use_model_for_training)
    logger.info(f'Using {_use_model_for_training} for training in folder {_model_root_folder}')

    if external_input['retrain'] is True:
        if external_input['model_type'] == 'XGBoost_gridsearch':
            _model_kwargs = external_input['model_parameters']['param_grid'].copy()
            for par in _model_kwargs:
                _model_kwargs[par] = np.arange(_model_kwargs[par][0], _model_kwargs[par][1], _model_kwargs[par][2])
            external_input['model_parameters']['param_grid'] = _model_kwargs
        if _distributed is True:
            _line = validate_parameter(args.line, default='86')
            if args.line is None:
                logger.warning(f'No line ID specified, proceeding with {_line}')
            train_models_wrapper(target_name=f'mom_zero_{_line}',
                                 model_type=external_input['model_type'],
                                 model_kwargs=external_input['model_parameters'],
                                 limit_rows=limit_rows,
                                 model_root_folder=_model_root_folder,
                                 use_validation=external_input['use_validation'],
                                 logger=logger)
        else:
            lines_to_process = (args.line,) if (args.line is not None) \
                else ('86', '87', '88', '256', '257', '381', '380')
            for line in lines_to_process:
                train_models_wrapper(target_name=f'mom_zero_{line}',
                                     model_type=external_input['model_type'],
                                     model_kwargs=external_input['model_parameters'],
                                     limit_rows=limit_rows,
                                     model_root_folder=_model_root_folder,
                                     use_validation=external_input['use_validation'],
                                     logger=logger)

    logger.info(f'Completed training for line {args.line}')
