import glob
import sys, os
import astropy.io.fits as pf
import numpy as np
import pylab as pl

# Read which nucleus (EAST or WEST) and which band to process from the
# command line; together they select the set of input images.
# Usage: <script> EAST|WEST BAND
try:
    nucleus = sys.argv[1]
except IndexError:
    # sys.exit(str) prints the message to stderr and exits with status 1,
    # so a calling shell/pipeline can detect the failure.
    sys.exit('Error: Must specify nucleus as EAST or WEST as first input!')

try:
    band = sys.argv[2]
except IndexError:
    sys.exit('Error: Must specify BAND, e.g. C as second input!')

if nucleus not in ('EAST', 'WEST'):
    sys.exit('Error: Must specify nucleus as EAST or WEST!')


# Make the (sorted, hence reproducible) list of input FITS images to read.
# Images are looked up in subdirectories of ../data/ -- presumably one
# subdirectory per observing epoch (TODO confirm).
path = os.path.join('..', 'data', '*',
                    '*_' + band + '_' + nucleus + '_IMAGE.FITS')
fits = sorted(glob.glob(path))
if not fits:
    # Fail early with a clear message instead of crashing later in the
    # stacking/header-reading code on an empty list.
    sys.exit('Error: No input images match ' + path)

# Define reference epoch for coordinate system etc.: its image is moved to
# the front of the list, because the output header is read from fits[0].
goodone = 'BB297A'
for i, f in enumerate(fits):
    if goodone in f:
        fits[0], fits[i] = fits[i], fits[0]
        break
else:
    # The output filename below still names goodone as the reference, so
    # make it visible when that epoch was not actually found.
    print('Warning: no image from reference epoch ' + goodone +
          ' found; header will be taken from ' + fits[0])
print('Reading files: ')
print(fits)

outfile = nucleus + '_' + band + '_' + goodone + 'REF_STACKED.fits'

# Stack data: pixel-by-pixel weighted mean of all input images, each image
# weighted by 1/sigma**2. sigma is the standard deviation of the first
# 2000 entries along the last axis of plane [0, 0] of the image; this
# indexing requires >= 4-D data (e.g. Stokes/frequency/y/x -- TODO confirm)
# and presumes that region is source-free.
imdatas = list(map(pf.getdata, fits))

wgts = []
for plane in imdatas:
    print(np.shape(plane))
    noise = np.std(plane[0, 0, :, :2000])
    wgts.append(1 / noise ** 2)

stacked_imdata = np.average(imdatas, axis=0, weights=wgts)

# Noise of an inverse-variance weighted mean, assuming the per-image noise
# is independent between images.
expected_noise = np.sqrt(1 / np.sum(wgts))
print('EXPECTED WEIGHTED NOISE =  %s' % expected_noise)

# SAVE RESULT: write the stacked image using the header of fits[0] (the
# reference-epoch image when it was found), replacing any existing
# output file of the same name.
header = pf.getheader(fits[0])
# `overwrite` superseded the `clobber` keyword, which was deprecated in
# astropy 1.3.
pf.writeto(outfile, stacked_imdata, header, overwrite=True)
