import numpy as np
import scipy as sp
from astropy.io import fits

class fitsimage(object):
    """Manage a FITS image and the header metadata needed to work with it.

    On construction the image plane, a residual copy of it, an RMS estimate,
    the reference-pixel coordinates, observation metadata, unit conversion
    factors and the CLEAN beam (parsed from the HISTORY cards) are stored as
    attributes. Small cutouts ("stamps") can then be taken around sky
    positions, and model stamps subtracted from the residual image.
    """

    def __init__(self, path, rms_cols=128):
        """Load the FITS file at *path* and read its metadata.

        Parameters
        ----------
        path : str
            Path to the FITS file. If it contains 'EAST' or 'WEST',
            ``self.nucleus`` is set to that string, otherwise to ``None``.
        rms_cols : int, optional
            Number of leading image columns, assumed to be source free, used
            to estimate the image RMS (default 128).

        Raises
        ------
        SystemExit
            If the absolute RA and Dec cell sizes differ, or if no CLEAN beam
            is found in the HISTORY cards (see ``read_beam``).
        """
        self.path = path

        # Load the image. Only the first plane along the two leading data
        # axes is used. NOTE(review): assumes 4-D data (e.g. Stokes, freq,
        # y, x) -- confirm for other inputs.
        self.data, self.fitsheader = fits.getdata(self.path, header=True)
        self.image = self.data[0, 0, :, :]
        # Working copy that model stamps are subtracted from while fitting.
        self.residual_image = np.copy(self.image)
        # RMS from a strip along one side of the image that is assumed to be
        # free of sources.
        self.imageRMS = np.std(self.image[:, 0:rms_cols])
        print('Image RMS is {0} uJy/beam'.format(round(self.imageRMS*1e6, 2)))

        hdr = self.fitsheader
        # Image size in pixels.
        self.ra_npix = int(hdr['NAXIS1'])
        self.dec_npix = int(hdr['NAXIS2'])
        # Reference pixel and its sky position in degrees.
        # NOTE(review): FITS CRPIX is 1-based while numpy indices are 0-based,
        # and int() drops any fractional CRPIX -- confirm this is intended.
        self.cra_pix = int(hdr['CRPIX1'])
        self.cdec_pix = int(hdr['CRPIX2'])
        # Only RA wraps at 360 deg. Dec lies in [-90, 90]; taking it modulo
        # 360 would turn e.g. -30 into 330 and break ac2rc for such fields.
        self.cra_deg = float(hdr['CRVAL1']) % 360
        self.cdec_deg = float(hdr['CRVAL2'])
        # Cell size (degrees per pixel) along RA and Dec.
        self.dra = float(hdr['CDELT1'])
        self.ddec = float(hdr['CDELT2'])
        # The rest of the class uses ddec as the cell size for both axes,
        # so square pixels are required.
        if abs(self.ddec) != abs(self.dra):
            print(self.ddec, self.dra)
            print('Cellsize differ in dec and ra! Not supported.')
            # Builtin equivalent of sys.exit (sys is not imported here).
            raise SystemExit(1)

        # Observation meta info.
        self.date = str(hdr['DATE-OBS'])
        self.experiment = str(hdr['OBSERVER'])
        self.freq = float(hdr['CRVAL3'])
        # Band letter -> [lower edge, upper edge, nominal frequency].
        # self.band becomes [letter, nominal frequency], or None if self.freq
        # lies outside every tabulated band.
        bands = {'L': [1.e9, 2.e9, 1.6], 'S': [2.e9, 4.e9, 2.3],
                 'C': [4.e9, 8.e9, 6.9], 'X': [8.e9, 12.e9, 8.4],
                 'U': [12.e9, 18.e9, 15], 'K': [18.e9, 27.e9, 22],
                 'Q': [27.e9, 45.e9, 43], 'W': [75.e9, 110.e9, 86]}
        self.band = None
        for letter, (lo, hi, nominal) in bands.items():
            # Half-open intervals: a frequency exactly on a band edge still
            # gets a band (strict '<' on both sides left it unassigned).
            if lo <= self.freq < hi:
                self.band = [letter, nominal]
                break

        # Which nucleus this image is of, taken from the file path.
        if 'EAST' in self.path:
            self.nucleus = 'EAST'
        elif 'WEST' in self.path:
            self.nucleus = 'WEST'
        else:
            self.nucleus = None

        # Unit conversion factors.
        dA = 77e6  # distance to Arp220 in pc, angular size distance
        dL = 80e6  # distance to Arp220 in pc, luminosity distance
        pc2m = 3.08567758e16  # Convert from pc to m
        deg2as = 3600.0
        deg2mas = deg2as*1000.0
        # Jy -> erg/s/Hz: 1 Jy = 1e-23 erg/s/cm^2/Hz, times 4*pi*dL^2 (dL in cm).
        self.jy2L = 1e-23*4*np.pi*(dL*pc2m*100)**2
        # Projected size in pc of 1 mas at the angular size distance dA.
        self.mas2pc = dA*1*1e-3/3600.0*np.pi/180
        self.pix2mas = self.ddec * deg2mas

        self.read_beam()

    def read_beam(self):
        """Read the CLEAN beam from the FITS HISTORY cards.

        Scans the HISTORY cards for a line containing the tokens 'CLEAN',
        'BMAJ=', 'BMIN=' and 'BPA=' and takes the token following each
        keyword as its value; if several lines match, the last one wins.

        Sets ``beam_deg`` ([BMAJ, BMIN, BPA] as written in the history),
        ``bmaj_pix``/``bmin_pix`` (beam FWHM in pixels), the matching
        Gaussian sigmas ``beam_sigma_maj_pix``/``beam_sigma_min_pix`` and
        ``beamarea_pix``.

        Raises
        ------
        SystemExit
            If no matching HISTORY line exists (or there are no HISTORY cards).
        """
        beamhist = ('CLEAN', 'BMAJ=', 'BMIN=', 'BPA=')
        beam = None
        for hline in self.fitsheader.get('HISTORY', []):
            hspl = hline.split()
            # All four tokens must be present. (len(filter(...)) fails on
            # Python 3, where filter returns an iterator.)
            if all(token in hspl for token in beamhist):
                i, j, k = [hspl.index(token) for token in beamhist[1:]]
                beam = [float(hspl[i+1]), float(hspl[j+1]), float(hspl[k+1])]
        # Test a local: self.beam_deg would simply not exist if nothing
        # matched, turning this check into an AttributeError.
        if beam is None:
            print('NO BEAM FOUND')
            raise SystemExit(1)
        self.beam_deg = beam

        # Beam FWHM in pixels; abs() keeps sizes positive if CDELT2 < 0.
        cell = abs(self.ddec)
        self.bmaj_pix = self.beam_deg[0]/cell
        self.bmin_pix = self.beam_deg[1]/cell
        # Gaussian FWHM = 2*sqrt(2 ln 2) * sigma.
        fwhm_per_sigma = 2*np.sqrt(2*np.log(2))
        self.beam_sigma_maj_pix = self.bmaj_pix/fwhm_per_sigma
        self.beam_sigma_min_pix = self.bmin_pix/fwhm_per_sigma
        # Beam area in pixels, needed to get flux densities from images;
        # computed as 2*pi*sigma_maj*sigma_min, the same way as in AIPS.
        self.beamarea_pix = 2*np.pi*self.beam_sigma_maj_pix*self.beam_sigma_min_pix

    def ac2rc(self, ra, dec):
        """Convert absolute coordinates to offsets from the image reference position.

        Parameters
        ----------
        ra, dec : float or numpy array
            Absolute position in degrees.

        Returns
        -------
        numpy.ndarray
            ``[rr, rd]``: the RA offset multiplied by cos(dec), and the Dec
            offset, both in degrees. The RA difference is wrapped into
            [-180, 180] so positions straddling RA = 0/360 are handled.
        """
        dra = ra - self.cra_deg
        # Wrap across RA = 0/360; exactly a no-op when |dra| < 180.
        dra = dra - 360.0*np.round(dra/360.0)
        rr = dra*np.cos(dec*np.pi/180.0)
        rd = dec - self.cdec_deg
        return np.array([rr, rd])

    def ac2ap(self, ra, dec):
        """Convert an absolute position (ra, dec) in degrees to pixel indices.

        Returns ``np.array([col, row])``. No bounds check is done here; the
        result may lie outside the image (including negative indices).
        """
        rc = self.ac2rc(ra, dec)
        # floor rather than int(): int() truncates toward zero, which would
        # map positions just off the image's low edge onto pixel 0.
        acol = int(np.floor(self.cra_pix + rc[0]/self.dra))
        arow = int(np.floor(self.cdec_pix + rc[1]/self.ddec))
        return np.array([acol, arow])

    def _stamp_bounds(self, ra_deg, dec_deg, nrow, ncol):
        """Return slice bounds (rl, ru, cl, cu) of an nrow x ncol stamp centred
        on (ra_deg, dec_deg), or None if the stamp does not fit in the image."""
        acol, arow = self.ac2ap(ra_deg, dec_deg)
        wcol = 0.5*ncol
        wrow = 0.5*nrow
        # Check the stamp's full extent on every side, so a stamp is never
        # silently truncated by numpy slicing.
        if (arow - wrow < 0 or acol - wcol < 0
                or arow + wrow > self.dec_npix or acol + wcol > self.ra_npix):
            return None
        return (int(arow - wrow), int(arow + wrow),
                int(acol - wcol), int(acol + wcol))

    def get_stamp(self, ra_deg, dec_deg, nrow, ncol, mode=''):
        """Get a small cutout around an absolute position (ra_deg, dec_deg).

        Returns an nrow x ncol slice (a numpy view, not a copy) of the image,
        or of the residual image when ``mode == 'residual'``.

        Raises
        ------
        ValueError
            If the stamp would extend beyond the image.
        """
        bounds = self._stamp_bounds(ra_deg, dec_deg, nrow, ncol)
        if bounds is None:
            raise ValueError("Position outside image. Cannot get stamp.")
        rl, ru, cl, cu = bounds
        source = self.residual_image if mode == 'residual' else self.image
        return source[rl:ru, cl:cu]

    def subtract_model_stamp(self, ra_deg, dec_deg, modelstamp):
        """Subtract a model stamp, centred on (ra_deg, dec_deg), in place from
        the residual image. May cause slight edge effects.

        Raises
        ------
        ValueError
            If the stamp would extend beyond the image.
        """
        nrow, ncol = np.shape(modelstamp)
        bounds = self._stamp_bounds(ra_deg, dec_deg, nrow, ncol)
        if bounds is None:
            raise ValueError("Position outside image. Cannot subtract stamp.")
        rl, ru, cl, cu = bounds
        self.residual_image[rl:ru, cl:cu] -= modelstamp
