package org.lsst.ccs.utilities.image;

import java.awt.Point;
import java.util.HashMap;
import java.util.Map;
import org.lsst.ccs.geometry.Geometry;
import org.lsst.ccs.utilities.ccd.CCD;
import org.lsst.ccs.utilities.ccd.CCDInterface;
import org.lsst.ccs.utilities.ccd.CCDType;
import org.lsst.ccs.utilities.ccd.Raft;
import org.lsst.ccs.utilities.ccd.SegmentInterface;

/**
 * Utility class for the generation of header specifications objects.
 *
 * @author turri
 */
public class FitsHeaderUtilities {

    /**
     * Utility function to return a pair of values in Fits standard.
     *
     * @param a First element of the pair (typically a range string)
     * @param b Second element of the pair (typically a range string)
     * @return Returns the pair of values in Fits Standard, formatted as
     *         {@code [a,b]}.
     */
    private static String pair(String a, String b) {
        return "[" + a + "," + b + "]";
    }

    /**
     * Utility function to return a range of values in Fits standard.
     *
     * @param a Lower edge
     * @param b Upper edge
     * @return Returns the range in Fits Standard, formatted as {@code a:b}.
     */
    private static String range(int a, int b) {
        return a + ":" + b;
    }

    /**
     * Utility function to return a range of values in Fits standard.
     *
     * @param a Lower edge
     * @param b Upper edge
     * @param flip boolean to indicate if the range should be flipped
     *        (i.e. written as {@code b:a}).
     * @return Returns the range in Fits Standard.
     */
    private static String range(int a, int b, boolean flip) {
        return flip ? range(b, a) : range(a, b);
    }

    /**
     * Create the default ImageSet for the given CCDInterface.
     * <p>
     * The ImageSet carries the CCD primary headers as its metadata and
     * contains one image per segment of the CCD. Each image's metadata set
     * holds the segment headers under the {@code "channel"} key.
     *
     * @param ccd The CCD for which to create the ImageSet
     * @return The corresponding ImageSet
     */
    public static ImageSet createImageSetForCCD(CCDInterface ccd) {
        Map<String, Object> metaData = new HashMap<>(getCCDPrimaryHeaders(ccd));
        DefaultImageSet imageSet = new DefaultImageSet(metaData);

        for (SegmentInterface segment : ccd.getSegments()) {
            MetaDataSet metaDataSet = new MetaDataSet();
            metaDataSet.addMetaData("channel", getSegmentHeaders(segment));

            // Each segment image is sized using the CCD type's total serial
            // and parallel segment sizes.
            imageSet.addImage(ccd.getType().getSegmentTotalSerialSize(),
                    ccd.getType().getSegmentTotalParallelSize(), metaDataSet);
        }
        return imageSet;
    }

    /**
     * Walk up the parent chain of a Geometry and return the first ancestor
     * that is an instance of the requested type.
     * <p>
     * The starting geometry itself is not considered, only its ancestors.
     *
     * @param <T> The requested ancestor type
     * @param geom The geometry whose ancestors are searched
     * @param clazz The class the ancestor must be an instance of
     * @return The nearest matching ancestor, or {@code null} if none is found.
     */
    private static <T> T findParentOfType(Geometry<?> geom, Class<T> clazz) {
        Geometry<?> current = geom.getParent();
        while (current != null) {
            // isInstance (rather than getClass().isAssignableFrom(clazz)) so that
            // subclasses of the requested type match and unrelated super-types do not.
            if (clazz.isInstance(current)) {
                return clazz.cast(current);
            }
            current = current.getParent();
        }
        return null;
    }

    /**
     * Create fits headers for the primary header for the given Object.
     * <p>
     * Currently no primary header keywords are generated: an empty, mutable
     * map is returned.
     *
     * @param ccd The CCD for which the primary header is returned.
     * @return A map containing the fits header names and values.
     */
    public static Map<String, Object> getCCDPrimaryHeaders(CCDInterface ccd) {
        return new HashMap<>();
    }

    /**
     * Create fits headers for a Segment.
     * <p>
     * The returned map contains:
     * <ul>
     * <li>{@code EXTNAME}, {@code CHANNEL} and {@code CCDSUM};</li>
     * <li>mosaic keywords {@code DTM1_1}, {@code DTM2_2}, {@code DTV1},
     * {@code DTV2}, {@code DETSIZE}, {@code DATASEC} and {@code DETSEC};</li>
     * <li>WCS {@code PC*} keywords for the A, C, R, F, B and Q coordinate
     * systems;</li>
     * <li>WCS {@code CRVAL*} keywords for the A, C and B systems.</li>
     * </ul>
     * The raft-level (R, Q) CRVAL keywords are added only when the segment is
     * a {@link Geometry} with a {@link CCD} ancestor. The focal-plane (F) CRVAL
     * keywords additionally require that CCD to have a {@link Raft} ancestor.
     * <p>
     * These headers were originally written at the request of users at BNL,
     * reportedly so that over/underscan regions can be viewed in DS9.
     *
     * @param segment The segment for which to build the Header.
     * @return A map containing the fits header names and values.
     */
    public static Map<String, Object> getSegmentHeaders(SegmentInterface segment) {

        Map<String, Object> imageMetaData = new HashMap<>();

        imageMetaData.put("EXTNAME", String.format("Segment%01d%01d", getSegmentParallelPosition(segment), getSegmentSerialPosition(segment)));
        imageMetaData.put("CHANNEL", segment.getChannel());
        imageMetaData.put("CCDSUM", "1 1");

        CCDType type = segment.getCCDType();
        double oneMinusTwoSx = getOneMinusTwoSx(segment);
        // For E2V devices this term follows 1 - 2*Sx; for any other type it is
        // always -1. Kept as double so header values remain Double, as before.
        double e2vOrMinusOne = type.equals(CCDType.E2V) ? oneMinusTwoSx : -1;

        //Mosaic Keywords
        imageMetaData.put("DTM1_1", e2vOrMinusOne);
        imageMetaData.put("DTM2_2", -1 * oneMinusTwoSx);
        imageMetaData.put("DTV1", getDTV1(segment));
        imageMetaData.put("DTV2", getDTV2(segment));
        // Detector size: 8*dimh by 2*dimv (presumably a 2x8 segment layout).
        imageMetaData.put("DETSIZE", pair(range(1, 8*type.getDimh()), range(1, 2*type.getDimv())));
        // Data section: skips the first preh columns of the segment.
        imageMetaData.put("DATASEC", pair(range(1+type.getPreh(), type.getPreh()+type.getDimh()), range(1, type.getDimv())));
        imageMetaData.put("DETSEC", getDETSEC(segment));

        //WCS keywords
        //Amplifier
        imageMetaData.put("PC1_2A", oneMinusTwoSx);
        imageMetaData.put("PC2_1A", e2vOrMinusOne);
        //CCD
        imageMetaData.put("PC1_2C", oneMinusTwoSx);
        imageMetaData.put("PC2_1C", e2vOrMinusOne);
        //Raft
        imageMetaData.put("PC1_2R", oneMinusTwoSx);
        imageMetaData.put("PC2_1R", e2vOrMinusOne);
        //Focal Plane
        imageMetaData.put("PC1_2F", oneMinusTwoSx);
        imageMetaData.put("PC2_1F", e2vOrMinusOne);
        //CCD Serial/Parallel
        imageMetaData.put("PC1_1B", e2vOrMinusOne);
        imageMetaData.put("PC2_2B", oneMinusTwoSx);
        //Raft Serial/Parallel
        imageMetaData.put("PC1_1Q", e2vOrMinusOne);
        imageMetaData.put("PC2_2Q", oneMinusTwoSx);

        //WCS CRVAL keywords
        //Amplifier
        imageMetaData.put("CRVAL1A", getCRVAL1A(segment));
        imageMetaData.put("CRVAL2A", getCRVAL2A(segment));
        //CCD
        imageMetaData.put("CRVAL1C", getCRVAL1C(segment));
        imageMetaData.put("CRVAL2C", getCRVAL2C(segment));
        //CCD Serial/Parallel: axes are swapped with respect to the C system.
        imageMetaData.put("CRVAL1B", getCRVAL2C(segment));
        imageMetaData.put("CRVAL2B", getCRVAL1C(segment));

        CCD ccd = null;
        if (segment instanceof Geometry) {
            ccd = findParentOfType((Geometry<?>) segment, CCD.class);
        }
        if (ccd != null) {
            //Raft
            imageMetaData.put("CRVAL1R", getCRVAL1R(segment, ccd));
            imageMetaData.put("CRVAL2R", getCRVAL2R(segment, ccd));
            //Raft Serial/Parallel: axes are swapped with respect to the R system.
            imageMetaData.put("CRVAL1Q", getCRVAL2R(segment, ccd));
            imageMetaData.put("CRVAL2Q", getCRVAL1R(segment, ccd));
            Raft raft = findParentOfType(ccd, Raft.class);
            if (raft != null) {
                //Focal Plane
                imageMetaData.put("CRVAL1F", getCRVAL1F(segment, ccd, raft));
                imageMetaData.put("CRVAL2F", getCRVAL2F(segment, ccd, raft));
            }
        }

        return imageMetaData;
    }

    /** Parallel position (Sx) of the segment within its CCD. */
    private static int getSegmentParallelPosition(SegmentInterface seg) {
        return seg.getParallelPosition();
    }

    /** Serial position (Sy) of the segment within its CCD. */
    private static int getSegmentSerialPosition(SegmentInterface seg) {
        return seg.getSerialPosition();
    }

    /** Delegates to {@link SegmentInterface#isReadoutDown()}. */
    private static boolean isReadoutDown(SegmentInterface seg) {
        return seg.isReadoutDown();
    }

    /** Delegates to {@link SegmentInterface#isReadoutLeft()}. */
    private static boolean isReadoutLeft(SegmentInterface seg) {
        return seg.isReadoutLeft();
    }

    /**
     * Returns 1 - 2*Sx, where Sx is the segment parallel position.
     *
     * @param seg The segment
     * @return 1 - 2*Sx, as a double
     */
    private static double getOneMinusTwoSx(SegmentInterface seg) {
        return 1 - 2*seg.getParallelPosition();
    }

    /** CRVAL1A: Sx*(dimv+1), computed in integer arithmetic. */
    private static double getCRVAL1A(SegmentInterface seg) {
        return seg.getParallelPosition()*(seg.getCCDType().getDimv()+1);
    }

    /**
     * CRVAL2A. For E2V: Sx*(dimh+1) + (2*Sx-1)*preh. For other types:
     * dimh + 1 - preh, independent of the segment position.
     */
    private static double getCRVAL2A(SegmentInterface seg) {
        CCDType type = seg.getCCDType();
        double sx = seg.getParallelPosition();
        if (type.equals(CCDType.E2V)) {
            return sx*(type.getDimh()+1) + (2*sx - 1)*type.getPreh();
        } else {
            return type.getDimh() + 1 - type.getPreh();
        }
    }

    /** CRVAL1C: Sx*(2*dimv+1), computed in integer arithmetic. */
    private static double getCRVAL1C(SegmentInterface seg) {
        return seg.getParallelPosition()*(2*seg.getCCDType().getDimv()+1);
    }

    /** CRVAL2C: CRVAL2A offset by Sy*dimh, where Sy is the serial position. */
    private static double getCRVAL2C(SegmentInterface seg) {
        CCDType type = seg.getCCDType();
        double sy = seg.getSerialPosition();
        return getCRVAL2A(seg) + sy*type.getDimh();
    }

    /**
     * CRVAL1R: CRVAL1C plus the outer gap, half the (px - ax) difference, and
     * an offset proportional to the CCD parallel position within its parent.
     * The CCD must be a {@link Geometry}.
     */
    private static double getCRVAL1R(SegmentInterface seg, CCDInterface ccd) {
        CCDType type = seg.getCCDType();
        return getCRVAL1C(seg) + type.getGap_outx() + 0.5*(type.getCCDpx()-type.getCCDax()) + ((Geometry<?>) ccd).getParallelPosition()*
                (2*type.getDimv()+type.getGap_inx()+type.getCCDpx()-type.getCCDax());
    }

    /**
     * CRVAL2R: CRVAL2C plus the outer gap, half the (py - ay) difference, and
     * an offset proportional to the CCD serial position within its parent.
     * The CCD must be a {@link Geometry}.
     */
    private static double getCRVAL2R(SegmentInterface seg, CCDInterface ccd) {
        CCDType type = seg.getCCDType();
        return getCRVAL2C(seg) + type.getGap_outy() + 0.5*(type.getCCDpy()-type.getCCDay()) + ((Geometry<?>) ccd).getSerialPosition()*
                (8*type.getDimh()+type.getGap_iny()+type.getCCDpy()-type.getCCDay());
    }

    /** CRVAL1F: CRVAL1R offset by the raft parallel position times raftx. */
    private static double getCRVAL1F(SegmentInterface seg, CCDInterface ccd, Raft raft) {
        CCDType type = seg.getCCDType();
        return getCRVAL1R(seg, ccd) + raft.getParallelPosition()*type.getRaftx();
    }

    /** CRVAL2F: CRVAL2R offset by the raft serial position times rafty. */
    private static double getCRVAL2F(SegmentInterface seg, CCDInterface ccd, Raft raft) {
        CCDType type = seg.getCCDType();
        return getCRVAL2R(seg, ccd) + raft.getSerialPosition()*type.getRafty();
    }

    /**
     * Value of the DTV1 keyword. For E2V devices the first term is scaled by
     * the parallel position Sx; for other types Sx is taken as 1.
     */
    private static int getDTV1(SegmentInterface seg) {
        CCDType type = seg.getCCDType();
        int sx = type.equals(CCDType.E2V) ? seg.getParallelPosition() : 1;
        return (type.getDimh() + 1 + 2*type.getPreh())*sx+seg.getSerialPosition()*type.getDimh()-type.getPreh();
    }

    /** Value of the DTV2 keyword: (1-Sx)*(2*dimv+1). */
    private static int getDTV2(SegmentInterface seg) {
        CCDType type = seg.getCCDType();
        return (1-seg.getParallelPosition())*(2*type.getDimv()+1);
    }

    /**
     * First x coordinate of DETSEC. For E2V it selects between
     * (Sy+1)*dimh (Sx=1) and Sy*dimh+1 (Sx=0); other types always use
     * (Sy+1)*dimh.
     */
    private static int getDSX1(SegmentInterface seg) {
        CCDType type = seg.getCCDType();
        int sx = seg.getParallelPosition();
        int sy = seg.getSerialPosition();
        if (type.equals(CCDType.E2V)) {
            int res = (sy+1)*type.getDimh()*sx;
            return res + (sy*type.getDimh()+1)*(1-sx);
        } else {
            return (sy+1)*type.getDimh();
        }
    }

    /**
     * Second x coordinate of DETSEC. For E2V it selects between
     * Sy*dimh+1 (Sx=1) and (Sy+1)*dimh (Sx=0); other types always use
     * Sy*dimh+1.
     */
    private static int getDSX2(SegmentInterface seg) {
        CCDType type = seg.getCCDType();
        int sx = seg.getParallelPosition();
        int sy = seg.getSerialPosition();
        if (type.equals(CCDType.E2V)) {
            int res = (sy*type.getDimh()+1)*sx;
            return res + (sy+1)*type.getDimh()*(1-sx);
        } else {
            return sy*type.getDimh()+1;
        }
    }

    /** First y coordinate of DETSEC: 2*dimv*(1-Sx)+Sx. */
    private static int getDSY1(SegmentInterface seg) {
        CCDType type = seg.getCCDType();
        int sx = seg.getParallelPosition();
        return 2*type.getDimv()*(1-sx)+sx;
    }

    /** Second y coordinate of DETSEC: (dimv+1)*(1-Sx)+dimv*Sx. */
    private static int getDSY2(SegmentInterface seg) {
        CCDType type = seg.getCCDType();
        int sx = seg.getParallelPosition();
        return (type.getDimv()+1)*(1-sx)+type.getDimv()*sx;
    }

    /** DETSEC value formatted as {@code [x1:x2,y1:y2]}. */
    private static String getDETSEC(SegmentInterface segment) {
        return pair(range(getDSX1(segment), getDSX2(segment)), range(getDSY1(segment), getDSY2(segment)));
    }

}
