package org.lsst.ccs.subsystem.rafts.fpga.compiler;

import java.math.BigDecimal;
import java.text.DecimalFormat;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import org.lsst.ccs.subsystem.rafts.fpga.compiler.FPGA2Model.FPGARoutine;
import org.lsst.ccs.subsystem.rafts.fpga.xml.Call;
import org.lsst.ccs.subsystem.rafts.fpga.xml.Channel;
import org.lsst.ccs.subsystem.rafts.fpga.xml.Function;
import org.lsst.ccs.subsystem.rafts.fpga.xml.FunctionPointer;
import org.lsst.ccs.subsystem.rafts.fpga.xml.RepeatFunctionPointer;
import org.lsst.ccs.subsystem.rafts.fpga.xml.RepeatSubroutinePointer;
import org.lsst.ccs.subsystem.rafts.fpga.xml.Subroutine;
import org.lsst.ccs.subsystem.rafts.fpga.xml.SubroutinePointer;
import org.lsst.ccs.subsystem.rafts.fpga.xml.Timeslice;

/**
 * Helper methods for extracting statistics from an {@link FPGA2Model}.
 * <p>
 * Functions and routines are looked up by id, and clock-cycle and pixel totals
 * can be computed for a single function or, recursively, for a routine and
 * everything it calls. Function/subroutine pointers and repeat-count pointers
 * are resolved through the parameter overrides supplied at construction,
 * falling back to the defaults stored in the model.
 *
 * @author tonyj
 */
public class FPGA2ModelUtils {

    private final FPGA2Model model;
    private final Map<String, Object> parameterOverrides;

    /**
     * Creates a utility instance bound to the given model.
     *
     * @param model the model whose functions and routines are inspected
     * @param parameterOverrides values keyed by pointer id which replace the
     *        defaults held by the model; {@code null} means "no overrides"
     */
    public FPGA2ModelUtils(FPGA2Model model, Map<String, Object> parameterOverrides) {
        this.model = model;
        // Tolerate a missing override map so later lookups do not fail with a NullPointerException.
        this.parameterOverrides = parameterOverrides != null
                ? parameterOverrides
                : java.util.Collections.<String, Object>emptyMap();
    }

    /**
     * Finds the function with the given id among the keys of the model's function map.
     *
     * @param name the function id to look for
     * @return the matching function
     * @throws IllegalArgumentException if the model has no function with that id
     */
    public Function getFunctionForName(String name) {
        for (Function f : model.functionsMap.keySet()) {
            if (f.getId().equals(name)) {
                return f;
            }
        }
        throw new IllegalArgumentException("Function not found: " + name);
    }

    /**
     * Finds the routine whose original definition has the given id.
     *
     * @param name the routine id to look for
     * @return the matching routine
     * @throws IllegalArgumentException if the model has no routine with that id
     */
    private FPGARoutine getRoutineForName(String name) {
        for (FPGA2Model.FPGARoutine f : model.routines) {
            if (f.orgRoutine.getId().equals(name)) {
                return f;
            }
        }
        throw new IllegalArgumentException("Routine not found: " + name);
    }

    /**
     * Sums the clock cycles of every timeslice of the named function.
     *
     * @param name the function id
     * @return the total number of clock cycles, or {@code -1} if the function
     *         has no timeslices (or its last timeslice is {@code null})
     * @throws IllegalArgumentException if the function does not exist
     */
    public long getClockCyclesForFunction(String name) {
        Function f = getFunctionForName(name);
        LinkedList<Timeslice> timeslices = f.getTimeslices();
        // LinkedList.getLast() throws on an empty list, so emptiness must be tested
        // explicitly for the "-1" sentinel to be reachable.
        if (timeslices == null || timeslices.isEmpty() || timeslices.getLast() == null) { // Should never happen?
            return -1;
        }
        long total = 0;
        for (Timeslice slice : timeslices) {
            total += slice.getClockCycles();
        }
        // Note: an extra "+ 2 (+ 1 more unless the function is named Default)" correction, meant to
        // compensate for 20ns (2 clock cycles) of overhead subtracted from the last slice,
        // used to be considered here. It is currently NOT applied; the raw sum is returned.
        return total;
    }

    /**
     * Counts pixels in the named function and records whether it raises SOI or EOI.
     * <p>
     * A pixel is counted each time the {@code TRG} channel is up in a timeslice
     * while it was not up in the preceding timeslice (i.e. on each rising edge).
     *
     * @param name the function id
     * @return the pixel count together with flags telling whether any timeslice
     *         has the {@code SOI} or {@code EOI} channel up
     * @throws IllegalArgumentException if the function does not exist
     */
    public PixelCount getPixelsForFunction(String name) {
        Function f = getFunctionForName(name);
        LinkedList<Timeslice> timeslices = f.getTimeslices();
        long pixelCount = 0;
        boolean hasSOI = false;
        boolean hasEOI = false;
        boolean wasSet = false;
        for (Timeslice slice : timeslices) {
            boolean trgSet = false;
            for (Channel c : slice.getUpChannels()) {
                switch (c.getId()) {
                    case "TRG":
                        trgSet = true;
                        break;
                    case "SOI":
                        hasSOI = true;
                        break;
                    case "EOI":
                        hasEOI = true;
                        break;
                    default:
                        // Other channels do not influence the pixel count.
                        break;
                }
            }
            // Only a transition from "not set" to "set" counts as a new pixel.
            if (trgSet && !wasSet) {
                pixelCount++;
            }
            wasSet = trgSet;
        }
        return new PixelCount(pixelCount, hasSOI, hasEOI);
    }

    // TODO we could template all this, let's duplicate first

    /**
     * Walks a list of calls, resolving each call's target name and repeat count,
     * and hands the result to the appropriate visitor.
     * <p>
     * Target resolution: if the call goes through a function/subroutine pointer,
     * an override keyed by the pointer id wins, otherwise the pointer's own
     * default name is used. Repeat resolution: if the call's literal repeat is
     * zero and a repeat pointer is present, an override keyed by the repeat
     * pointer id wins, otherwise the pointer's default count is used. A zero
     * repeat without a repeat pointer is passed through unchanged.
     * Calls that reference neither a function nor a subroutine are ignored.
     *
     * @param calls the calls to visit, in order
     * @param visitFunction receives (function name, repeat) for function calls
     * @param visitRoutine receives (subroutine name, repeat) for subroutine calls
     */
    private void visitCallListLong(List<Call> calls, BiConsumer<String, LongWithInfinity> visitFunction, BiConsumer<String, LongWithInfinity> visitRoutine) {
        for (Call call : calls) {
            LongWithInfinity repeat = new LongWithInfinity(call.getRepeatValue(), call.isInfinity());
            Function function = call.getFunction();
            FunctionPointer functionPointer = call.getFunctionPointer();
            Subroutine subroutine = call.getSubroutine();
            SubroutinePointer subroutinePointer = call.getSubroutinePointer();

            if (function != null || functionPointer != null) {
                String fcnName = call.getFcnName();
                if (functionPointer != null) {
                    Object nameOverride = parameterOverrides.get(functionPointer.getId());
                    if (nameOverride != null) {
                        fcnName = nameOverride.toString();
                    } else {
                        fcnName = functionPointer.getFcnName();
                    }
                }
                RepeatFunctionPointer repeatFcnPtr = call.getRepeatFcnPtr();
                // A literal zero repeat with no pointer is legal; only dereference the pointer if present.
                if (repeat.isZero() && repeatFcnPtr != null) {
                    Object repeatOverride = parameterOverrides.get(repeatFcnPtr.getId());
                    if (repeatOverride != null) {
                        repeat = new LongWithInfinity(overrideAsLong(repeatFcnPtr.getId(), repeatOverride));
                    } else {
                        repeat = new LongWithInfinity(repeatFcnPtr.getN());
                    }
                }
                visitFunction.accept(fcnName, repeat);
            } else if (subroutine != null || subroutinePointer != null) {
                String subName = call.getSubName();
                if (subroutinePointer != null) {
                    Object nameOverride = parameterOverrides.get(subroutinePointer.getId());
                    if (nameOverride != null) {
                        subName = nameOverride.toString();
                    } else {
                        subName = subroutinePointer.getSubName();
                    }
                }
                RepeatSubroutinePointer repeatSubPtr = call.getRepeatSubPtr();
                if (repeat.isZero() && repeatSubPtr != null) {
                    Object repeatOverride = parameterOverrides.get(repeatSubPtr.getId());
                    if (repeatOverride != null) {
                        repeat = new LongWithInfinity(overrideAsLong(repeatSubPtr.getId(), repeatOverride));
                    } else {
                        repeat = new LongWithInfinity(repeatSubPtr.getN());
                    }
                }
                visitRoutine.accept(subName, repeat);
            }
        }
    }

    /**
     * Converts a repeat-count override to a {@code long}.
     * <p>
     * Any {@link Number} is accepted (fractional values are truncated by
     * {@link Number#longValue()}); any other object is parsed from its string form.
     *
     * @param id the pointer id the override belongs to, used only for error reporting
     * @param value the non-null override value
     * @return the override as a long
     * @throws IllegalArgumentException if the value cannot be interpreted as an integer
     */
    private static long overrideAsLong(String id, Object value) {
        if (value instanceof Number) {
            return ((Number) value).longValue();
        }
        try {
            return Long.parseLong(value.toString().trim());
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException(
                    "Override for repeat count '" + id + "' is not an integer: " + value, e);
        }
    }

    /**
     * Computes the total clock cycles taken by the named routine, including
     * all functions and subroutines it calls, each weighted by its repeat count.
     * <p>
     * Subroutines are followed recursively; no protection against routines that
     * call each other in a cycle is provided.
     *
     * @param name the routine id
     * @return the total number of clock cycles, flagged infinite if any
     *         contributing repeat count is infinite
     * @throws IllegalArgumentException if the routine, or anything it calls, does not exist
     */
    public LongWithInfinity getClockCyclesForRoutine(String name) {
        FPGARoutine f = getRoutineForName(name);
        // Single-element array so the lambdas below can update the running total.
        LongWithInfinity[] totalClockCycles = { LongWithInfinity.ZERO };
        // Find all the called routines and functions, plus the repeat count
        Subroutine orgRoutine = (Subroutine) f.orgRoutine;
        List<Call> calls = orgRoutine.getCalls();
        visitCallListLong(calls, (fcnName, repeat) -> {
            totalClockCycles[0] = totalClockCycles[0].add(repeat.times(getClockCyclesForFunction(fcnName)));
        }, (subName, repeat) -> {
            totalClockCycles[0] = totalClockCycles[0].add(repeat.times(getClockCyclesForRoutine(subName)));
        });
        return totalClockCycles[0];
    }

    /**
     * Computes the number of pixels produced by the named routine, starting
     * with the SOI state inactive.
     *
     * @param name the routine id
     * @return the total pixel count, flagged infinite if any contributing repeat count is infinite
     * @throws IllegalArgumentException if the routine, or anything it calls, does not exist
     */
    public LongWithInfinity getPixelsForRoutine(String name) {
        return getPixelsForRoutine(name, false);
    }

    /**
     * Computes the number of pixels produced by the named routine.
     * <p>
     * Calls are processed in order. A function that raises SOI switches counting
     * on, a function that raises EOI switches it off (EOI wins if both are
     * raised), and a function's pixels are added only while counting is on
     * after that function's own SOI/EOI flags have been applied. Subroutines
     * inherit the current state; note that state changes made inside a
     * subroutine are not propagated back to the caller.
     *
     * @param name the routine id
     * @param soiActive whether counting is already switched on when the routine starts
     * @return the total pixel count
     */
    private LongWithInfinity getPixelsForRoutine(String name, boolean soiActive) {
        FPGARoutine f = getRoutineForName(name);
        // Single-element arrays so the lambdas below can mutate the running state.
        LongWithInfinity[] totalPixels = { LongWithInfinity.ZERO };
        boolean[] currentSoiActive = { soiActive };
        Subroutine orgRoutine = (Subroutine) f.orgRoutine;
        List<Call> calls = orgRoutine.getCalls();
        visitCallListLong(calls, (fcnName, repeat) -> {
            PixelCount pixelsForFunction = getPixelsForFunction(fcnName);
            currentSoiActive[0] |= pixelsForFunction.hasSOI;
            currentSoiActive[0] &= !pixelsForFunction.hasEOI;
            if (currentSoiActive[0]) {
                totalPixels[0] = totalPixels[0].add(repeat.times(pixelsForFunction.pixelCount));
            }
        }, (subName, repeat) -> {
            totalPixels[0] = totalPixels[0].add(repeat.times(getPixelsForRoutine(subName, currentSoiActive[0])));
        });
        return totalPixels[0];
    }

    /**
     * An immutable {@code long} that may additionally be flagged as infinite.
     * <p>
     * Arithmetic propagates the infinite flag: a result is infinite if any
     * operand is. The numeric part uses plain {@code long} arithmetic and is
     * not checked for overflow.
     */
    public static class LongWithInfinity {

        private static final LongWithInfinity ZERO = new LongWithInfinity(0);

        private final long value;
        private final boolean isInfinite;

        /**
         * @param value the numeric value (meaningful only when not infinite)
         * @param isInfinite whether this quantity is infinite
         */
        LongWithInfinity(long value, boolean isInfinite) {
            this.value = value;
            this.isInfinite = isInfinite;
        }

        /**
         * Creates a finite value.
         *
         * @param value the numeric value
         */
        public LongWithInfinity(long value) {
            this(value, false);
        }

        /**
         * @return {@code "Infinity"} if infinite, otherwise the value with grouping separators
         */
        @Override
        public String toString() {
            return isInfinite ? "Infinity" : String.format("%,d", value);
        }

        /**
         * @return {@code true} if this is finite and equal to zero
         */
        public boolean isZero() {
            return !isInfinite && value == 0;
        }

        /**
         * @param lwv the value to add
         * @return the sum, infinite if either operand is infinite
         */
        LongWithInfinity add(LongWithInfinity lwv) {
            return new LongWithInfinity(value + lwv.value, isInfinite || lwv.isInfinite);
        }

        /**
         * @param lwv the value to multiply by
         * @return the product, infinite if either operand is infinite
         */
        LongWithInfinity times(LongWithInfinity lwv) {
            return new LongWithInfinity(value * lwv.value, isInfinite || lwv.isInfinite);
        }

        /**
         * @param multiplier the finite factor to multiply by
         * @return the product, infinite if this value is infinite
         */
        LongWithInfinity times(long multiplier) {
            return new LongWithInfinity(value * multiplier, isInfinite);
        }

        /**
         * Scales this value by a decimal factor.
         *
         * @param multiplier the factor to multiply by
         * @return the product as a decimal quantity, infinite if this value is infinite
         */
        public FloatWithInfinity times(BigDecimal multiplier) {
            return new FloatWithInfinity(multiplier.multiply(BigDecimal.valueOf(value)), isInfinite);
        }

        /**
         * @return the numeric part, regardless of the infinite flag
         */
        long getValue() {
            return value;
        }
    }

    /**
     * An immutable decimal value that may additionally be flagged as infinite.
     */
    public static class FloatWithInfinity {

        /** Pattern used by {@link #toString()}: grouping separators, at most three decimals. */
        private static final String FORMAT_PATTERN = "#,##0.###";

        private final BigDecimal value;
        private final boolean isInfinite;

        /**
         * @param value the numeric value (meaningful only when not infinite)
         * @param isInfinite whether this quantity is infinite
         */
        FloatWithInfinity(BigDecimal value, boolean isInfinite) {
            this.value = value;
            this.isInfinite = isInfinite;
        }

        /**
         * Creates a finite value.
         *
         * @param value the numeric value
         */
        public FloatWithInfinity(BigDecimal value) {
            this(value, false);
        }

        /**
         * @return {@code "Infinity"} if infinite, otherwise the formatted value
         */
        @Override
        public String toString() {
            if (isInfinite) {
                return "Infinity";
            }
            // DecimalFormat is not thread-safe, so a fresh instance is used per call
            // instead of sharing one static formatter between threads.
            return new DecimalFormat(FORMAT_PATTERN).format(value);
        }

        /**
         * @return {@code true} if this is finite and numerically equal to zero (scale is ignored)
         */
        public boolean isZero() {
            return !isInfinite && value.compareTo(BigDecimal.ZERO) == 0;
        }

        /**
         * @return the numeric part, regardless of the infinite flag
         */
        BigDecimal getValue() {
            return value;
        }
    }

    /**
     * Result of analysing a single function: how many pixels it produces and
     * whether it raises the SOI and/or EOI channels.
     */
    public static class PixelCount {

        private final long pixelCount;
        private final boolean hasSOI;
        private final boolean hasEOI;

        private PixelCount(long pixelCount, boolean hasSOI, boolean hasEOI) {
            this.pixelCount = pixelCount;
            this.hasSOI = hasSOI;
            this.hasEOI = hasEOI;
        }

        /**
         * @return {@code "SOI "} and/or {@code "EOI "} when the respective flag is set,
         *         followed by the pixel count if it is greater than zero
         */
        @Override
        public String toString() {
            return (hasSOI ? "SOI " : "") + (hasEOI ? "EOI " : "") + (pixelCount > 0 ? String.format("%,d", pixelCount) : "");
        }

    }
}
