from fitslib import *
from uvlib import *
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from operator import itemgetter
import glob
import os, sys, time
from scipy.optimize import least_squares
import emcee, corner

class fitter(object):
    """Fit shell models to catalogued sources in one image, plot and save the results.

    The fit results are stored in ``self.allres`` (one row per source in
    ``self.srcat``) as ``[flux, drow, dcol, size]``, where ``drow``/``dcol``
    are pixel offsets from the stamp centre and ``size`` is in pixels.
    Sources that were skipped or too weak to fit get an all-zero row.
    """

    def __init__(self, fitsimage, srcat, stampsize = (256,256)):
        """Store the image and catalogue and precompute the uv-plane beam.

        Parameters
        ----------
        fitsimage : image object
            Must provide the attributes read below (imageRMS, bmaj_pix,
            bmin_pix, beam_deg, pix2mas, beam_sigma_maj_pix,
            beam_sigma_min_pix, path) and the methods used by the other
            methods of this class.
        srcat : sequence
            Sources to fit; each entry is indexed as [name, ra, dec, ...].
        stampsize : tuple of int
            (nrow, ncol) of the stamps cut out around each source.
        """
        self.fitsimage = fitsimage
        # A source is only fitted if its central peak exceeds 3x the image RMS.
        self.thresh = 3*self.fitsimage.imageRMS
        self.beam_pix = {
            'maj': self.fitsimage.bmaj_pix,
            'min': self.fitsimage.bmin_pix,
            'pa': self.fitsimage.beam_deg[2],  # in degrees
        }

        # Positional error margin in pixels: 2.0 divided by pix2mas
        # (pix2mas is used as a pixel -> mas factor in write_results).
        error_pix = 2.0/self.fitsimage.pix2mas
        self.error_ra_pix = error_pix
        self.error_dec_pix = error_pix

        self.srcat = srcat
        self.nrow = stampsize[0]
        self.ncol = stampsize[1]
        self.uvmesh = build_uvmesh(self.nrow, self.ncol)
        self.uvbeam = twoDgaussian_uv(self.uvmesh, 1.0, 0, 0, self.fitsimage.beam_sigma_maj_pix, self.fitsimage.beam_sigma_min_pix, self.beam_pix['pa']*np.pi/180.0, mode = 'beam')

        # Plots go to a 'fitplots' directory next to the image. The trailing
        # separator is kept because plot_single concatenates file names onto it.
        self.plotdir = os.path.join(os.path.dirname(self.fitsimage.path), 'fitplots', '')
        # Create the directory without going through a shell, so paths with
        # spaces or shell metacharacters are handled correctly.
        os.makedirs(self.plotdir, exist_ok=True)

    @staticmethod
    def _plain_repr(value):
        """Return repr(value), unwrapping NumPy scalars to plain Python values first.

        With NumPy >= 2 the repr of a scalar includes its type
        (e.g. ``np.float64(1.0)``), which must not end up in the results file.
        """
        if isinstance(value, np.generic):
            value = value.item()
        return repr(value)

    def fit_all(self):
        """Fit every source in self.srcat, in order, and store the rows in self.allres.

        The order matters: fit_single subtracts each fitted model from the
        residual image before the next source is fitted.
        """
        print('Fitting ' + str(len(self.srcat)) + ' sources...')
        allres = []
        for sour in self.srcat:
            res = self.fit_single(sour)
            print('Appending res ', res)
            allres.append(res)
        # Convert to numpy arrays for easy slicing in write_results
        self.allres = np.array(allres)

    def fit_single(self, sour):
        """Fit one source and subtract the fitted model from the residual image.

        Parameters
        ----------
        sour : sequence
            Catalogue entry indexed as [name, ra, dec, ...].

        Returns
        -------
        sequence of 4 numbers
            [flux, drow, dcol, size] with the position given as pixel offsets
            from the stamp centre, or [0, 0, 0, 0] if no stamp could be
            obtained or the source is too weak to fit.
        """
        # Read data guess for this source
        name = sour[0]
        ra = float(sour[1])
        dec = float(sour[2])
        try:
            stamp = self.fitsimage.get_stamp(ra, dec, self.nrow, self.ncol, mode = 'residual')
        except ValueError:
            print('Could not obtain stamp. This source is likely outside the image. This is normal for U-band which is zoomed. ')
            stamp = None

        if stamp is None:
            print("SKIPPING ", name, " SINCE OUTSIDE IMAGE.")
            return [0, 0, 0, 0]

        # Central search box: stamp centre +/- half the error margin.
        crow = 0.5*self.nrow
        wrow = 0.5*self.error_dec_pix
        ccol = 0.5*self.ncol
        wcol = 0.5*self.error_ra_pix
        testarr = stamp[int(crow - wrow):int(crow + wrow), int(ccol - wcol):int(ccol + wcol)]

        # Check so that there is a peak of at least thresh in the central
        # region bounded by the error_pix values. If not, this source should
        # not be fitted since the fit may be unreliable.
        if not (np.max(testarr) > self.thresh):
            print('TOO WEAK TO FIT: ' + name + ': fit not attempted.')
            return [0, 0, 0, 0]

        print('FITTING ' + name + '...')

        # Fit shell to data; parameters are [flux, row, col, size] in
        # absolute stamp pixels, with the position confined to the search box.
        guess = [0.5e-3, crow, ccol, 1]
        lower_bounds = [0, crow-wrow, ccol-wcol, 0]
        upper_bounds = [np.inf, crow+wrow, ccol+wcol, 8/self.fitsimage.pix2mas]
        fitres = least_squares(optf, guess, bounds = (lower_bounds, upper_bounds), args = (stamp, self.uvmesh, self.uvbeam, self.fitsimage.imageRMS), verbose=0)
        absolute = fitres['x']

        # Construct model from fitted values with same dimensions as stamp
        resuv = optically_thin_shell_uv(self.uvmesh, *absolute)
        resim = np.real(np.fft.ifft2(resuv*self.uvbeam))
        # Subtract fitted model from residual image to simplify fitting of weaker sources
        self.fitsimage.subtract_model_stamp(ra, dec, resim)

        # Store the position relative to the stamp centre.
        res = np.array(absolute, dtype=float)
        res[1] -= crow
        res[2] -= ccol
        return res

    def plot_all(self, save = True):
        """Plot every source in self.srcat together with its row of self.allres."""
        for i, sour in enumerate(self.srcat):
            self.plot_single(sour, self.allres[i], save = save)

    def plot_single(self, sour, res, save = False):
        """Save a three-panel PDF (raw stamp, fitted model, residual) for one source.

        Parameters
        ----------
        sour : sequence
            Catalogue entry indexed as [name, ra, dec, ...].
        res : sequence of 4 numbers
            Fit result as returned by fit_single; it is not modified.
        save : bool
            Currently ignored: the figure is always written to disk.
        """
        name = sour[0]
        print('Saving plot of source', name)
        ra = float(sour[1])
        dec = float(sour[2])
        try:
            residualstamp = self.fitsimage.get_stamp(ra, dec, self.nrow, self.ncol, mode = 'residual')
        except ValueError:
            print('Could not obtain stamp. This source is likely outside the image. This is normal for U-band which is zoomed. ')
            residualstamp = None

        if residualstamp is None:
            return

        bunit = 1e6  # scale factor applied to pixel values; colorbars are labelled muJy/beam
        deg2mas = 3600.0*1000
        dra = self.fitsimage.dra
        ddec = self.fitsimage.ddec
        crow = 0.5*self.nrow
        ccol = 0.5*self.ncol
        extent = np.array([ccol*dra, -ccol*dra, -crow*ddec, crow*ddec])*deg2mas

        f, (ax1, ax2, ax3) = plt.subplots(1,3, sharey=True)
        rawstamp = self.fitsimage.get_stamp(ra, dec, self.nrow, self.ncol, mode = '')
        if res[0]>0:
            # Shift position from stored relative to center to absolute pixels
            # by adding center pixels. Build a new list so that the caller's
            # result row (e.g. a view into self.allres) is left untouched.
            args = [res[0], res[1] + crow, res[2] + ccol, res[3]]
            shelluv = optically_thin_shell_uv(self.uvmesh, *args)
            shellim = np.real(np.fft.ifft2(shelluv*self.uvbeam))
        else:
            shellim = np.zeros_like(rawstamp)

        axes = [ax1, ax2, ax3]
        for ax, panel in zip(axes, [rawstamp, shellim, residualstamp]):
            # matplotlib accepts only 'upper' or 'lower' for origin.
            im = ax.imshow(bunit*panel, origin = 'lower', interpolation = 'None', extent = extent)
            # Horizontal colorbar stacked on top of each panel.
            cax = make_axes_locatable(ax).append_axes("top", size="10%", pad=0.05)
            f.colorbar(im, ax=ax, cax = cax, orientation="horizontal", label = r'$\mu$Jy/beam')
            cax.xaxis.set_ticks_position('top')
            cax.xaxis.set_label_position('top')
            cax.set_xticklabels(cax.get_xticklabels(), rotation='vertical')

        ax1.set_ylabel('Rel. Dec. [mas]')
        for ax in axes:
            ax.set_xlabel('Rel. R.A. [mas]')
            # Dashed cross-hair through the stamp centre.
            ax.plot(extent[0:2],[0, 0], 'w--', linewidth = 0.5)
            ax.plot([0, 0],extent[2:4], 'w--', linewidth = 0.5)
        plotfile = self.plotdir + self.fitsimage.experiment + '_' + self.fitsimage.band[0] + '_' + self.fitsimage.nucleus + '_' +name
        f.savefig(plotfile + '_FIT.pdf', bbox_inches='tight',dpi=300)
        plt.close(f)

    def write_results(self):
        """Save fitted results to the text file '<image path>.dat'.

        Writes a KEY=VALUE header followed by one line per source with
        flux, diameter and x/y offsets (pixel values scaled by pix2mas).
        Nothing is written if self.allres is empty.
        """
        # The output file sits next to the image: image path plus '.dat'
        outfile = self.fitsimage.path + '.dat'
        print('Saving fit results to file ' + outfile)
        if len(self.allres) == 0:
            return

        deg2mas = 3600.0*1000
        fmt = self._plain_repr

        flux = self.allres[:,0]
        row = self.allres[:,1]*self.fitsimage.pix2mas
        col = self.allres[:,2]*self.fitsimage.pix2mas
        size = self.allres[:,3]*self.fitsimage.pix2mas

        lines = [
            'EXPERIMENT=' + self.fitsimage.experiment + '\n',
            'NUCLEUS=' + (self.fitsimage.nucleus) + '\n',
            'DATE='+self.fitsimage.date + '\n',
            'FREQ=' + fmt(self.fitsimage.freq) + '#Hz\n',
            'MAPRMS=' + fmt(self.fitsimage.imageRMS) + '#Jy/beam\n',
            'THRESHHOLD=' + fmt(self.thresh) + '#Jy/beam\n',
            'BMAJ=' + str(self.fitsimage.beam_deg[0]*deg2mas) + '#major beam axis in milliarcsec\n',
            'BMIN=' + str(self.fitsimage.beam_deg[1]*deg2mas) + '#minor beam axis in milliarcsec\n',
            'BPA=' + str(self.fitsimage.beam_deg[2]) + '#beam position angle in degrees\n',
            'NAME FLUX[Jy] DIAMETER[mas] DX[mas] DY[mas]\n',
        ]
        for i, sour in enumerate(self.srcat):
            # Columns: name, flux, diameter, dx (column offset), dy (row offset)
            fields = [sour[0], fmt(flux[i]), fmt(size[i]), fmt(col[i]), fmt(row[i])]
            lines.append(' '.join(fields) + '\n')

        # The context manager guarantees the file is closed even on error.
        with open(outfile, 'w') as of:
            of.writelines(lines)

def read_source_catalog(catfile):
    """Read a source catalog produced by PyBDSM.

    Every line that does not start with '#' and is not blank must hold at
    least five whitespace-separated columns: name, R.A. [deg], Dec. [deg],
    major axis [deg] and flux [Jy].

    Parameters
    ----------
    catfile : str
        Path to the catalogue text file.

    Returns
    -------
    list
        One ``[name, ra, dec, flux, smaj]`` entry per source, sorted by
        decreasing flux.

    Raises
    ------
    IndexError, ValueError
        If a data line has too few columns or a non-numeric value.
    """
    cat = []
    # The context manager closes the catalogue file once it has been read.
    with open(catfile) as catalog:
        for line in catalog:
            if line.startswith('#'):
                continue
            # split() without argument tolerates repeated spaces and tabs
            source = line.split()
            if not source:
                # Blank line: nothing to parse
                continue
            name = source[0]
            ra = float(source[1]) # Deg
            dec = float(source[2]) # Deg
            smaj = float(source[3]) # Deg
            flux = float(source[4]) # Jy
            cat.append([name, ra, dec, flux, smaj])
    # Sort catalogue according to flux density, thereby starting to fit the brightest sources.
    # This is not always true since flux changes between epochs, but at least a rough guess.
    cat.sort(key=itemgetter(3), reverse=True)
    return cat

# Glob patterns selecting the experiments and bands to process ('*' matches all).
exp = '*'
band = '*'
#exp = 'BB335B'
#band = 'C'
path = '../data/' + exp + '_' + band + '/*_' + band + '*ST_IMAGE.FITS'
images = glob.glob(path)

srcat = read_source_catalog('./arp220_source_catalog.txt')

for image in images:
    print('FITTING ' + image + '...')
    im2fit = fitsimage(image)

    # Keep only the sources in the same nucleus as this image. The nucleus is
    # determined from R.A.: the first four characters of the source name are
    # read as a number, above 0.25 means EAST and otherwise WEST.
    sr2fit = [
        sour for sour in srcat
        if ('EAST' if float(sour[0][0:4]) > 0.25 else 'WEST') == im2fit.nucleus
    ]
    # To fit a single source only, filter on e.g. sour[0]=='0.2212+0.444' instead.

    # Bands L, S and U get a 256x256 stamp, all other bands a 128x128 stamp.
    stampsize = (256,256) if im2fit.band[0] in ['L','S','U'] else (128,128)

    fit = fitter(im2fit, sr2fit, stampsize = stampsize)
    # If GL015, reorder list for better results since very different fluxes
    if im2fit.experiment == 'GL015':
        fit.srcat = fit.srcat[::-1]
    fit.fit_all()
    fit.write_results()
    #fit.plot_all()
    print('DONE WITH ' + image + '.')
