package org.lsst.ccs.utilities.image;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.Closeable;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import nom.tam.fits.BasicHDU;
import nom.tam.fits.BinaryTableHDU;
import nom.tam.fits.Fits;
import nom.tam.fits.FitsException;
import nom.tam.fits.FitsFactory;
import nom.tam.fits.FitsUtil;
import nom.tam.fits.Header;
import nom.tam.fits.HeaderCardException;
import nom.tam.util.BufferedFile;
import static org.lsst.ccs.utilities.image.HeaderSpecification.DataType.Float;
import org.lsst.ccs.utilities.image.HeaderSpecification.HeaderLine;
import org.lsst.ccs.utilities.image.ImageSet.Image;

/**
 * A utility for writing FITS files following LSST conventions.
 * <p>
 * The file-based constructor lays out the whole file up front, in this order:
 * <ul>
 * <li>the primary header;</li>
 * <li>for each image, an IMAGE extension header followed by a reserved,
 * padded area for its pixel data;</li>
 * <li>a BINTABLE header for every entry of the header specification map other
 * than "primary" and "extended".</li>
 * </ul>
 * Pixel data is then streamed into the reserved areas with
 * {@link #write(int, ByteBuffer)}, in as many pieces as convenient.
 * <p>
 * The no-argument constructor opens no file. Such an instance is only useful
 * for {@link #addBinaryTable} and {@link #getFluxStats}, which operate on FITS
 * files named by their arguments.
 * <p>
 * NOTE(review): nothing here synchronizes access to the shared file pointer,
 * so presumably an instance must be confined to a single thread; confirm with
 * callers.
 *
 * @author tonyj
 */
public class FitsFileWriter implements Closeable {

    /** Initial number of slots in {@link #image_hdu}; grown when an ImageSet has more images. */
    private static final int DEFAULT_IMAGE_HDU_SLOTS = 16;

    /** Suffix of the temporary file written by {@link #addBinaryTable} before it replaces the original. */
    private static final String NEW_FILE_SUFFIX = "-new";

    private static final Logger LOGGER = Logger.getLogger(FitsFileWriter.class.getName());

    private final BufferedFile bf;
    // File offset of each image extension header. The extra final entry is the
    // offset just past the last image's reserved (padded) data area.
    private final long[] position0;
    // The current write position within each image's data area.
    private final long[] position;
    // File offset where each image's pixel data starts (just after its header).
    private final long[] dataStart;
    // File offset just past each image's pixel data, excluding FITS padding.
    private final long[] dataEnd;
    private BasicHDU image_hdu[] = new BasicHDU[DEFAULT_IMAGE_HDU_SLOTS];

    public enum BitsPerPixel {

        /**
         * 16 bits per pixel
         */
        BIT16,
        /**
         * 32 bits per pixel
         */
        BIT32
    }

    /**
     * Creates a writer that is not attached to any file. Only
     * {@link #addBinaryTable} and {@link #getFluxStats} can be used on such an
     * instance; {@link #close()} is a no-op.
     */
    public FitsFileWriter() {
        bf = null;
        position0 = null;
        position = null;
        dataStart = null;
        dataEnd = null;
    }

    /**
     * Open an LSST FITS file for writing a CCD ImageSet.
     * <p>
     * All headers are written immediately and space is reserved for each
     * image's pixel data. The pixel data itself must then be supplied through
     * {@link #write(int, ByteBuffer)}. If construction fails, the file is
     * closed before the exception propagates.
     *
     * @param file The file to write to
     * @param images The ImageSet to write to the file. Note that this specifies
     * the images to write, not the actual data for the images.
     * @param metaData The meta-data maps to use to extract header info from
     * @param config The configuration which controls how meta data is written
     * to the file
     * @param bits The number of bits per pixel for images
     * @throws java.io.IOException If the file cannot be written, or a required
     * header specification is missing or has an incompatible value
     * @throws nom.tam.fits.FitsException If a header cannot be built
     */
    public FitsFileWriter(File file, ImageSet images, Map<String, Map<String, Object>> metaData, Map<String, HeaderSpecification> config, BitsPerPixel bits) throws IOException, FitsException {
        int nImages = images.getImages().size();
        position0 = new long[nImages + 1];
        position = new long[nImages];
        dataStart = new long[nImages];
        dataEnd = new long[nImages];
        if (nImages > image_hdu.length) {
            // Without this, ImageSets with more images than the default
            // capacity would overflow the array.
            image_hdu = new BasicHDU[nImages];
        }
        bf = new BufferedFile(file, "rw");

        boolean written = false;
        try {
            FitsFactory.setUseHierarch(true);
            writePrimaryHeader(metaData, config);
            writeImageHeaders(images, config, bits);
            writeBinaryTableHeaders(metaData, config);
            written = true;
        } finally {
            if (!written) {
                // Don't leak the file handle when the constructor fails.
                closeQuietly(bf);
            }
        }
    }

    /** Writes the primary header built from the "primary" specification. */
    private void writePrimaryHeader(Map<String, Map<String, Object>> metaData, Map<String, HeaderSpecification> config) throws IOException, FitsException {
        BasicHDU primary = BasicHDU.getDummyHDU();
        addMetaDataToHeader(primary, "primary", metaData, config);
        primary.getHeader().write(bf);
    }

    /** Writes one IMAGE extension header per image and reserves (pads) space for its pixel data. */
    private void writeImageHeaders(ImageSet images, Map<String, HeaderSpecification> config, BitsPerPixel bits) throws IOException, FitsException {
        boolean sixteenBit = bits == BitsPerPixel.BIT16;
        // Only the header generated from this placeholder array is written here;
        // the real pixel data goes through write().
        Object dummyData = sixteenBit ? new short[1][1] : new int[1][1];
        long bytesPerPixel = sixteenBit ? 2L : 4L;

        int i = 0;
        for (Image image : images.getImages()) {
            BasicHDU hdu = FitsFactory.HDUFactory(dummyData);
            image_hdu[i] = hdu;
            addMetaDataToHeader(hdu, "extended", image.getMetaData(), config);
            if (sixteenBit) {
                // To store as unsigned 16 bit values, we have to set BZERO to 32768, and
                // subtract 32768 from each value. For now we have cheated, and stored them
                // (incorrectly) as signed 16 bit values.
                // See: http://heasarc.gsfc.nasa.gov/docs/software/fitsio/c/c_user/node23.html
                hdu.addValue("BSCALE", 1.0, "Unsigned 16 bit data");
                hdu.addValue("BZERO", 32768, "Unsigned 16 bit data");
            }
            Header header = hdu.getHeader();
            header.setXtension("IMAGE");
            header.setNaxis(1, image.getWidth());
            header.setNaxis(2, image.getHeight());

            position0[i] = bf.getFilePointer();
            header.write(bf);
            dataStart[i] = bf.getFilePointer();
            position[i] = dataStart[i];

            long imageSize = bytesPerPixel * image.getWidth() * image.getHeight();
            dataEnd[i] = dataStart[i] + imageSize;
            // Skip over the data area, then pad it out so the next header starts aligned.
            bf.seek(dataEnd[i]);
            FitsUtil.pad(bf, imageSize);
            i++;
        }
        position0[i] = bf.getFilePointer();
    }

    /** Writes a BINTABLE header for every specification other than "primary" and "extended". */
    private void writeBinaryTableHeaders(Map<String, Map<String, Object>> metaData, Map<String, HeaderSpecification> config) throws IOException, FitsException {
        FitsFactory.setUseAsciiTables(false);
        Object[] tableDummyData = new Object[0];
        for (String key : config.keySet()) {
            if ("primary".equals(key) || "extended".equals(key)) {
                continue;
            }
            BasicHDU binary = FitsFactory.HDUFactory(tableDummyData);
            addMetaDataToHeader(binary, key, metaData, config);
            Header header = binary.getHeader();
            header.setXtension("BINTABLE");
            header.write(bf);
        }
    }

    /** Closes {@code file}, ignoring a secondary failure while another exception is propagating. */
    private static void closeQuietly(BufferedFile file) {
        try {
            file.close();
        } catch (IOException ignored) {
            // The exception that aborted construction is the one worth reporting.
        }
    }

    /**
     * Returns the HDUs created for the image extensions, indexed by image.
     * Slots beyond the number of images are null. The internal array is
     * returned, not a copy.
     *
     * @return the image extension HDUs
     */
    public BasicHDU[] getImage_hdu() {
        return image_hdu;
    }

    /**
     * Write the actual image data to the file. It is not necessary that all of
     * the data for the image be available at once, this method will write
     * whatever data is currently available in the byte buffer to the specified
     * image, and will keep track of how much has been written to each image to
     * allow more data to be written later. This method assumes the data is
     * given in the order it is to be written to the file. If any data
     * reordering is needed it needs to be done before calling this method.
     * <p>
     * On return the buffer's position equals its limit.
     *
     * @param imageIndex The image to which this data is to be written
     * @param src The image data; all remaining bytes are written
     * @throws IOException If an IOException is generated, or if more data is
     * sent than was expected for a particular image.
     */
    public void write(int imageIndex, ByteBuffer src) throws IOException {
        int length = src.remaining();
        // Bound by the end of the pixel data, not the end of the padded area,
        // so excess data cannot spill into the FITS padding.
        if (length > dataEnd[imageIndex] - position[imageIndex]) {
            throw new IOException("Too much data written for image: " + imageIndex);
        }
        bf.seek(position[imageIndex]);
        if (src.hasArray()) {
            bf.write(src.array(), src.arrayOffset() + src.position(), length);
            src.position(src.limit());
        } else {
            // Direct or read-only buffers: copy through a bounded scratch array
            // rather than writing one byte at a time.
            byte[] chunk = new byte[Math.min(length, 8192)];
            while (src.hasRemaining()) {
                int n = Math.min(chunk.length, src.remaining());
                src.get(chunk, 0, n);
                bf.write(chunk, 0, n);
            }
        }
        position[imageIndex] += length;
    }

    /**
     * Closes the underlying file. This does nothing for an instance created
     * with the no-argument constructor.
     *
     * @throws IOException If closing the file fails
     */
    @Override
    public void close() throws IOException {
        if (bf != null) {
            bf.close();
        }
    }

    /**
     * Adds meta-data to the header of an image extension already written by
     * the constructor, and rewrites that header in place.
     * <p>
     * Only the header is rewritten. The image's data area, and any pixel data
     * already written to it, is left untouched. The rewritten header must
     * occupy exactly the space originally reserved for it; otherwise an
     * IOException is thrown.
     *
     * @param image_index index of the image whose header is updated
     * @param specName name of the header specification in {@code config} to apply
     * @param metaData the meta-data maps to extract header values from
     * @param config the header specifications
     * @throws HeaderCardException If a header card cannot be created
     * @throws IOException If the specification is missing, a value has the
     * wrong type, writing fails, or the header no longer fits its reserved space
     * @throws FitsException If the header cannot be written
     */
    public void addMetaDataToHeader(int image_index, String specName, Map<String, Map<String, Object>> metaData, Map<String, HeaderSpecification> config) throws HeaderCardException, IOException, FitsException {
        BasicHDU hdu = getImage_hdu()[image_index];
        bf.seek(position0[image_index]);
        addMetaDataToHeader(hdu, specName, metaData, config);
        // Write the header only: BasicHDU.write would also write the HDU's 1x1
        // placeholder data array (plus padding) over the start of the real pixel data.
        hdu.getHeader().write(bf);
        if (bf.getFilePointer() != dataStart[image_index]) {
            throw new IOException("Rewritten header for image " + image_index
                    + " does not fit the space reserved for it; the image data area may have been overwritten");
        }
    }

    /**
     * Adds each non-null value described by the named header specification
     * to {@code hdu}, converting it to the specification's declared data type.
     *
     * @throws IOException If the specification is missing, or a value cannot
     * be converted to its declared type
     */
    private void addMetaDataToHeader(BasicHDU hdu, String specName, Map<String, Map<String, Object>> metaData, Map<String, HeaderSpecification> config) throws HeaderCardException, IOException {
        HeaderSpecification spec = config.get(specName);
        if (spec == null) {
            throw new IOException("Missing specification for header: " + specName);
        }
        for (HeaderLine header : spec.getHeaders()) {
            Object value = header.getValue(metaData);
            if (value == null) {
                continue;
            }
            try {
                switch (header.getDataType()) {
                    case Integer:
                        hdu.addValue(header.getKeyword(), ((Number) value).intValue(), header.getComment());
                        break;
                    case Float:
                        hdu.addValue(header.getKeyword(), ((Number) value).doubleValue(), header.getComment());
                        break;
                    case Boolean:
                        hdu.addValue(header.getKeyword(), (Boolean) value, header.getComment());
                        break;
                    default:
                        hdu.addValue(header.getKeyword(), String.valueOf(value), header.getComment());
                }
            } catch (ClassCastException x) {
                throw new IOException(String.format("Meta-data header %s with value %s(%s) cannot be converted to type %s", header.getKeyword(), value, value.getClass(), header.getDataType()), x);
            }
        }
    }

    /**
     * Replaces a binary-table extension of a FITS file with a two-column
     * (time, value) table read from a text file. It also records the mean
     * value during the exposure in the primary header as MONDIODE.
     * <p>
     * Details:
     * <ul>
     * <li>The first extension whose EXTNAME contains {@code extnam} is
     * replaced; all other HDUs are copied unchanged.</li>
     * <li>The MONDIODE value is the luminous-region mean divided by 1e-9 (to
     * nA, assuming the input values are in amperes; TODO confirm).</li>
     * <li>The result is written to {@code fitsfile + "-new"} and then renamed
     * over {@code fitsfile}.</li>
     * <li>Failures are logged, not thrown.</li>
     * </ul>
     *
     * @param pddatfile text file of whitespace-separated "time value" lines
     * @param fitsfile the FITS file to modify in place
     * @param extnam substring identifying the EXTNAME of the extension to replace
     * @param c1name name given to the time column
     * @param c2name name given to the value column
     * @param tstart value written to the TSTART keyword of the table header
     */
    public void addBinaryTable(String pddatfile, String fitsfile, String extnam, String c1name, String c2name, Double tstart) {
        List<Double> tmdata = new ArrayList<>();
        List<Double> pddata = new ArrayList<>();
        double pdavg_expo = readPhotoDiodeData(pddatfile, tmdata, pddata);

        System.out.println("Trying to open FITS file " + fitsfile);
        try {
            Fits f = new Fits(fitsfile);
            Fits f1 = new Fits();
            f.read();
            int nhdu = f.getNumberOfHDUs();
            System.out.println("Number of HDUs = " + nhdu);

            f1.addHDU(f.getHDU(0)); // add the primary header to the new FITS instance
            double mondiode = pdavg_expo / 1.e-9; // convert to nA
            System.out.println("Setting MONDIODE value in primary header to " + mondiode);
            Header primaryHeader = f1.getHDU(0).getHeader();
            if (primaryHeader.containsKey("MONDIODE")) {
                primaryHeader.deleteKey("MONDIODE");
            }
            primaryHeader.addValue("MONDIODE", mondiode, "avg PD value (nA) during expo");

            // Copy HDUs into the new FITS until the extension to be replaced is found.
            BinaryTableHDU bhdu = null;
            int bhduidx = 0;
            for (int ihdu = 1; ihdu < nhdu; ihdu++) {
                BasicHDU hdu = f.getHDU(ihdu);
                String extn = hdu.getHeader().getStringValue("EXTNAME");
                System.out.println("iHDU = " + ihdu + " EXTNAME = " + extn + " searching for " + extnam);
                if (extn != null && extn.contains(extnam)) {
                    if (!(hdu instanceof BinaryTableHDU)) {
                        LOGGER.log(Level.SEVERE, "Extension {0} is not a binary table", extn);
                        return;
                    }
                    bhdu = (BinaryTableHDU) hdu;
                    bhduidx = ihdu;
                    System.out.println("Found the extension to be modified.");
                    break;
                }
                f1.addHDU(hdu);
            }
            if (bhdu == null) {
                LOGGER.log(Level.SEVERE, "Failed to find the requested extension {0}", extnam);
                return;
            }

            bhdu.addColumn(toPrimitive(tmdata));
            bhdu.addColumn(toPrimitive(pddata));

            bhdu.getHeader().setNaxis(1, 16);
            bhdu.getHeader().setNaxis(2, pddata.size());
            bhdu.getHeader().addValue("TFIELDS", 2, " ");
            bhdu.setColumnName(0, c1name, "time");
            bhdu.setColumnName(1, c2name, "values");

            System.out.println("Number of columns = " + bhdu.getNCols());
            System.out.println("Number of rows = " + bhdu.getNRows());
            // Show what we have made
            bhdu.info();

            bhdu.getHeader().addValue("TSTART", tstart, "Time of Start of Readings");

            // Add the replacement HDU, then the remaining HDUs.
            f1.addHDU(bhdu);
            for (int ihdu = bhduidx + 1; ihdu < nhdu; ihdu++) {
                f1.addHDU(f.getHDU(ihdu));
            }

            String newFile = fitsfile + NEW_FILE_SUFFIX;
            Path newPath = Paths.get(newFile);
            // "rw" does not truncate, so a stale leftover must be removed first.
            Files.deleteIfExists(newPath);
            BufferedFile bfnew = new BufferedFile(newFile, "rw");
            try {
                f1.write(bfnew);
                bfnew.flush();
            } finally {
                bfnew.close();
            }

            // Rename in-process rather than spawning "mv": this works with paths
            // containing spaces, reports failures, and the replacement is
            // complete when this method returns.
            Files.move(newPath, Paths.get(fitsfile), StandardCopyOption.REPLACE_EXISTING);
        } catch (Exception ex) {
            // Best-effort boundary: failures are logged for the caller's operator, not thrown.
            LOGGER.log(Level.SEVERE, "Failed to add binary table " + extnam + " to " + fitsfile, ex);
        }
    }

    /**
     * Reads samples from a text file of buffered Bias or PhotoDiode values.
     * Times are appended to {@code tmdata} and values to {@code pddata}. The
     * method returns the mean of the "luminous" samples, meaning those whose
     * magnitude exceeds the magnitude of the mean of all samples.
     * <p>
     * Each non-blank line must hold a time and a value separated by
     * whitespace; other lines are reported and skipped. Problems are reported
     * on standard output rather than thrown:
     * <ul>
     * <li>A missing file yields empty lists and 0.0.</li>
     * <li>A read failure keeps the samples read so far but returns 0.0.</li>
     * </ul>
     */
    private static double readPhotoDiodeData(String pddatfile, List<Double> tmdata, List<Double> pddata) {
        File pdFl = new File(pddatfile);
        try {
            if (!pdFl.exists()) {
                System.out.println("Cannot find the input file of PD values");
            }
        } catch (SecurityException e) {
            System.out.println("Failed to verify existence of file" + e);
        }

        double pdsum = 0.0;
        try (BufferedReader in = new BufferedReader(new FileReader(pdFl))) {
            System.out.println("reading file of buffered values from either Bias or PhotoDiode device");
            String line;
            while ((line = in.readLine()) != null) {
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    continue;
                }
                String[] fields = trimmed.split("\\s+");
                double pdtime;
                double pdval;
                try {
                    if (fields.length < 2) {
                        throw new NumberFormatException("expected two fields");
                    }
                    pdtime = Double.parseDouble(fields[0]);
                    pdval = Double.parseDouble(fields[1]);
                } catch (NumberFormatException e) {
                    System.out.println("Skipping malformed PD data line: " + line);
                    continue;
                }
                tmdata.add(pdtime);
                pddata.add(pdval);
                pdsum += pdval;
            }
        } catch (FileNotFoundException e) {
            System.out.println("Failed to open reader stream for file (" + pddatfile + ") for reason " + e);
            return 0.0;
        } catch (IOException e) {
            System.out.println("Failed to read input PD data" + e);
            return 0.0;
        }

        double pdavg_all = pddata.isEmpty() ? 0.0 : pdsum / pddata.size();
        System.out.println("Average of all PD values is " + pdavg_all);

        double luminousSum = 0.0;
        int nelem = 0;
        for (double pdval : pddata) {
            if (Math.abs(pdval) > Math.abs(pdavg_all)) {
                luminousSum += pdval;
                nelem++;
            }
        }
        double pdavg_expo = nelem > 0 ? luminousSum / nelem : 0.0;
        System.out.println("Number elements in luminous curve region = " + nelem);
        System.out.println("Sum of PD values in luminous region = " + luminousSum);
        System.out.println("PD average value during exposure = " + pdavg_expo);
        return pdavg_expo;
    }

    /** Unboxes {@code values} into a new primitive array, preserving order. */
    private static double[] toPrimitive(List<Double> values) {
        double[] out = new double[values.size()];
        for (int i = 0; i < out.length; i++) {
            out[i] = values.get(i);
        }
        return out;
    }

    /**
     * Computes the mean flux over the extensions of a FITS file whose EXTNAME
     * contains "Segment".
     * <p>
     * Each segment's flux is (AVWOBIAS - AVGBIAS) * CCDGAIN / EXPTIME,
     * clamped below at 1.0. CCDGAIN and EXPTIME are taken from the primary
     * header. Failures are logged, not thrown.
     *
     * @param fitsfile the FITS file to read
     * @return the mean flux (electrons/second) over the segments found, or
     * {@code Double.NaN} if no segment could be read
     */
    public double getFluxStats(String fitsfile) {
        double minflux = Double.MAX_VALUE;
        double fluxsum = 0.0;
        int fluxentries = 0;
        System.out.println("Trying to open FITS file " + fitsfile);
        try {
            Fits f = new Fits(fitsfile);
            f.read();
            int nhdu = f.getNumberOfHDUs();
            System.out.println("Number of HDUs = " + nhdu);
            Header primary = f.getHDU(0).getHeader();
            double gain = primary.getDoubleValue("CCDGAIN");
            double exptime = primary.getDoubleValue("EXPTIME");

            for (int ihdu = 1; ihdu < nhdu; ihdu++) {
                Header header = f.getHDU(ihdu).getHeader();
                String extn = header.getStringValue("EXTNAME");
                System.out.println("iHDU = " + ihdu + " EXTNAME = " + extn + " searching for Segment");
                if (extn != null && extn.contains("Segment")) {
                    double avg = header.getDoubleValue("AVWOBIAS"); // data region only
                    double bias = header.getDoubleValue("AVGBIAS"); // bias region only
                    double signal = (avg - bias) * gain;         // in electrons
                    double flux = Math.max(signal / exptime, 1.0);  // in electrons/second (can't be < 0 !)
                    minflux = Math.min(minflux, flux);
                    fluxsum += flux;
                    fluxentries++;
                    System.out.println(extn + " signal = " + signal + " exptime = " + exptime + " flux = " + flux + " minflux for CCD = " + minflux);
                }
            }
        } catch (Exception ex) {
            // Best-effort boundary: log and fall through to whatever was accumulated.
            LOGGER.log(Level.SEVERE, "Failed to compute flux statistics for " + fitsfile, ex);
        }

        return fluxentries > 0 ? fluxsum / fluxentries : Double.NaN;
    }

}
