#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Dec 17 14:32:22 2021

@author: smordini

Utility functions for SED (spectral energy distribution) fitting.
"""
import numpy as np
from scipy import interpolate
import math


def match(list_a,list_b):
    """Find the positions of the values two sequences have in common.

    For every distinct value present in both ``list_a`` and ``list_b``, the
    index of its first occurrence in each sequence is reported. Pairs are
    ordered by first appearance in ``list_a`` (deterministic, unlike set
    iteration order), and ``pos_a[k]`` / ``pos_b[k]`` always refer to the
    same value.

    Parameters
    ----------
    list_a, list_b : sequence
        Sequences of hashable values.

    Returns
    -------
    tuple of (list of int, list of int)
        ``(pos_a, pos_b)`` such that ``list_a[pos_a[k]] == list_b[pos_b[k]]``.
    """
    # Earliest index of each value in list_b (setdefault keeps the first one,
    # matching list.index semantics) -- one pass instead of O(n) per lookup.
    first_b = {}
    for j, val in enumerate(list_b):
        first_b.setdefault(val, j)

    pos_a = []
    pos_b = []
    seen = set()
    for i, val in enumerate(list_a):
        # Only the first occurrence of each common value in list_a counts.
        if val in first_b and val not in seen:
            seen.add(val)
            pos_a.append(i)
            pos_b.append(first_b[val])
    return (pos_a, pos_b)

def tostring(in_str):
    """Return ``in_str`` as a string stripped of surrounding whitespace.

    Anything that is not exactly a ``str`` is converted with ``str()`` first.
    """
    text = in_str if type(in_str) == str else str(in_str)
    return text.strip()


def pad0_num(num):
    """Return ``num`` as a string, appending ``'0'`` after a decimal part.

    A trailing ``'0'`` is added when the string contains a ``'.'`` followed
    by at least one character; otherwise the string is returned unchanged.
    Non-string inputs are converted with ``str()``. If a string input does
    not parse as a float, a message is printed but processing continues.

    Examples: ``1.5 -> '1.50'``, ``'3' -> '3'``, ``'4.' -> '4.'``.
    """
    if isinstance(num, str):
        try:
            float(num)
        except ValueError:
            # Best-effort warning only; the string is still processed below.
            print('The string must be a number. Retry')
    else:
        num = str(num)

    dot = num.find('.')
    # NOTE(review): this pads whenever at least one character follows the
    # dot, so '1.25' -> '1.250' as well as '1.5' -> '1.50'. If the intent was
    # to pad only single-decimal values, the check would be
    # ``dot == len(num) - 2`` -- confirm with callers.
    if dot != -1 and dot <= len(num) - 2:
        return num + '0'
    return num

def remove_char(in_str,ch):
    """Return ``in_str`` with every occurrence of the substring ``ch`` removed.

    Inputs that are not exactly ``str`` are converted with ``str()`` first;
    if ``ch`` does not occur, the (converted) string is returned as-is.
    """
    text = in_str if type(in_str) == str else str(in_str)
    return text if text.find(ch) == -1 else text.replace(ch, '')

def planck(wavelength, temp):
    """Evaluate the blackbody (Planck) surface flux per unit wavelength.

    Computes ``c1 / (lambda**5 * (exp(c2 / (lambda*T)) - 1))`` with lambda in
    cm and cgs constants (c1 = 2*pi*h*c**2, c2 = h*c/k), then multiplies by
    1e-8 to convert from per-cm to per-Angstrom, i.e. erg s^-1 cm^-2 A^-1.

    Parameters
    ----------
    wavelength : float or sequence of float
        Wavelength(s) in Angstrom. A scalar (anything without ``len``) is
        treated as a one-element list.
    temp : float
        A single temperature in Kelvin. Sequences are rejected.

    Returns
    -------
    float, list or None
        A scalar when exactly one wavelength is given, otherwise a list.
        Points whose exponent would overflow a float give 0. ``None`` is
        returned (after printing a message) if ``temp`` is a sequence or
        ``wavelength`` is empty.
    """
    try:
        len(wavelength)
    except TypeError:
        # Scalar wavelength: wrap it so it can be iterated below.
        wavelength = [wavelength]
    try:
        len(temp)
    except TypeError:
        pass
    else:
        print('Only one temperature allowed, not list')
        return
    if len(wavelength)<1:
        print('Enter at least one wavelength in Angstrom.')
        return

    c1 = 3.7417749e-5  # 2*pi*h*c**2 in cgs units (erg cm^2 s^-1)
    c2 = 1.4387687     # h*c/k in cgs units (cm K)
    # Largest argument math.expm1 accepts before overflowing a float.
    max_exponent = math.log(np.finfo(float).max)

    bbflux = []
    for wl_angstrom in wavelength:
        wl = wl_angstrom*1e-8  # Angstrom -> cm
        x = c2/wl/temp
        if x < max_exponent:
            bbflux.append(c1/(wl**5*math.expm1(x))*1e-8)
        else:
            # exp(x) would overflow; the Wien tail is effectively zero here.
            bbflux.append(0)
    if len(bbflux) == 1:
        return bbflux[0]
    return bbflux


def lbol(wavelength,flux,dist):
    """Compute the bolometric luminosity of an SED in solar luminosities.

    The SED is converted to F_lambda and resampled onto 1000 points spanning
    the full observed wavelength range. The resampling uses linear
    interpolation in log-log space, so power-law segments stay straight. The
    result is transformed back to linear space and integrated with the
    trapezoidal rule.

    Parameters
    ----------
    wavelength : sequence of float
        Wavelengths in micron (at least two distinct values).
    flux : sequence of float
        Flux densities in Jy, one per wavelength; must be > 0 (log10 is taken).
    dist : float
        Distance in parsec.

    Returns
    -------
    float
        Bolometric luminosity in L_sun (using L_sun = 3.827e26 W and
        1 pc = 3.086e18 cm).
    """
    w = np.asarray(wavelength, dtype=float)
    f = np.asarray(flux, dtype=float)

    # Jy -> W cm^-2 um^-1: 1 Jy = 1e-30 W cm^-2 Hz^-1 and c ~ 2.997e14 um/s,
    # hence 1e-15 * 0.2997 / lambda**2.
    # NOTE(review): c is 2.9979e14 um/s; 0.2997 is kept from the original.
    fw = 1.e-15*f*.2997/(w**2.)
    lw = np.log10(w)
    lfw = np.log10(fw)

    # 1000-point resampling that includes BOTH endpoints, so the integral
    # really covers the whole observed range.
    lw1 = np.linspace(lw.min(), lw.max(), 1000)
    interpfunc = interpolate.interp1d(lw, lfw, kind='linear')
    w1 = 10**lw1
    fw1 = 10**interpfunc(lw1)

    # Trapezoidal rule over every interval of the resampled grid (W cm^-2).
    fint = np.sum((fw1[1:] + fw1[:-1])*np.diff(w1)/2.)

    # L = 4 pi d^2 F, with d converted from pc to cm, then W -> L_sun.
    lum = fint*4*np.pi*dist*3.086e18/3.827e26*dist*3.086e18
    return lum

def wheretomulti(array, indices):
    """Convert flat indices into per-axis indices of a 2-D or 3-D array.

    Port of IDL's WHERETOMULTI adapted to numpy's row-major (C) order, in
    which the LAST axis varies fastest.

    Parameters
    ----------
    array : numpy.ndarray
        Only its ``shape`` is used.
    indices : int or numpy.ndarray of int
        Flat C-order index / indices into ``array``.

    Returns
    -------
    tuple
        ``(Col, Row)`` for a 2-D array of shape ``(nrow, ncol)``;
        ``(Col, Row, Frame)`` for a 3-D array of shape
        ``(nframe, nrow, ncol)``; ``(0, 0, 0)`` (after printing a message)
        for any other rank.
    """
    s = array.shape
    if len(s) == 2:
        NCol = s[1]
        # Integer division: '/' would give fractional row numbers.
        return (indices % NCol, indices // NCol)
    if len(s) == 3:
        # numpy shape order is (frame, row, col); col varies fastest.
        NRow, NCol = s[1], s[2]
        Col = indices % NCol
        Row = (indices // NCol) % NRow
        Frame = indices // (NRow * NCol)
        return (Col, Row, Frame)
    # Rank checked before reading s[1], so 1-D input lands here instead of
    # raising IndexError.
    print('WhereToMulti called with bad input. Array not a vector or matrix.')
    return (0, 0, 0)
