package org.lsst.ccs.subsystem.rafts.fpga.compiler;

import java.math.BigDecimal;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.lsst.ccs.subsystem.rafts.fpga.compiler.FPGA2Model.EndSliceTime;
import org.lsst.ccs.subsystem.rafts.fpga.compiler.FPGA2Model.FPGARoutine;
import org.lsst.ccs.subsystem.rafts.fpga.compiler.FPGA2Model.SliceTime;
import org.lsst.ccs.subsystem.rafts.fpga.compiler.FPGA2Model.SliceValues;
import org.lsst.ccs.subsystem.rafts.fpga.compiler.FPGA2Model.StackOpCode;
import org.lsst.ccs.subsystem.rafts.fpga.compiler.FPGA2Model.StackReturnOpcode;
import org.lsst.ccs.subsystem.rafts.fpga.compiler.FPGA2Model.StackSubroutineOpCode;
import org.lsst.ccs.subsystem.rafts.fpga.compiler.FPGA2Model.StackSubroutineRepPtrOpCode;
import org.lsst.ccs.subsystem.rafts.fpga.xml.Call;
import org.lsst.ccs.subsystem.rafts.fpga.xml.Callable;
import org.lsst.ccs.subsystem.rafts.fpga.xml.Channel;
import org.lsst.ccs.subsystem.rafts.fpga.xml.Function;
import org.lsst.ccs.subsystem.rafts.fpga.xml.FunctionPointer;
import org.lsst.ccs.subsystem.rafts.fpga.xml.Main;
import org.lsst.ccs.subsystem.rafts.fpga.xml.Parameter;
import org.lsst.ccs.subsystem.rafts.fpga.xml.RepeatFunctionPointer;
import org.lsst.ccs.subsystem.rafts.fpga.xml.RepeatSubroutinePointer;
import org.lsst.ccs.subsystem.rafts.fpga.xml.Sequencer;
import org.lsst.ccs.subsystem.rafts.fpga.xml.Subroutine;
import org.lsst.ccs.subsystem.rafts.fpga.xml.SubroutinePointer;
import org.lsst.ccs.subsystem.rafts.fpga.xml.Timeslice;

/**
 * The visitor which builds the final {@link FPGA2Model}.
 * <p>
 * While the sequencer description is walked, functions, function/subroutine
 * pointers and their repeat variants are numbered in visiting order, and every
 * {@link Main} / {@link Subroutine} is compiled into an {@link FPGARoutine}.
 * Routine slot 0 of the model is reserved for the jump table, the mains follow
 * it in the order they are found, and all other routines (including anonymous
 * embedded subroutines) come after them. Once the whole {@link Sequencer} has
 * been visited, function timing is computed, routine base addresses are
 * assigned and the jump table is filled in.
 *
 * @author aubourg
 */
class FPGA2ModelBuilderVisitor extends ModelBuilderVisitor {

    private static final Logger LOG = Logger.getLogger(FPGA2ModelBuilderVisitor.class.getName());

    /** Id of the function that always gets function index 0. */
    private static final String DEFAULT_FUNCTION_ID = "Default";

    private final FPGA2Model model = new FPGA2Model();

    /** Clock period (ns) forced by the caller, or {@code null} to keep the one from the file. */
    private final BigDecimal clockPeriodOverride;

    /** Number of mains inserted so far right after the jump table. */
    private int nMains = 0;

    // Counters assigning indices, in visiting order, to functions and the various pointer kinds.
    int curFuncIndex = 0;
    int nextFuncIndex = 1;
    int curFuncPtrIndex = 0;
    int curSubPtrIndex = 0;
    int curRepFuncPtrIndex = 0;
    int curRepSubPtrIndex = 0;

    /**
     * Creates the visitor and reserves routine slot 0 for the jump table.
     *
     * @param clockPeriodOverride clock period replacing the one declared in the
     *        sequencer file, or {@code null} to use the file value
     */
    FPGA2ModelBuilderVisitor(BigDecimal clockPeriodOverride) {
        // Leave space for jump table
        FPGARoutine jumpTable = new FPGARoutine();
        jumpTable.orgRoutine = new JumpTable();
        model.routines.add(jumpTable);
        this.clockPeriodOverride = clockPeriodOverride;
    }

    /**
     * @return the model being built; it is complete only after
     *         {@link #visit(Sequencer)} has run
     */
    public FPGA2Model getModel() {
        return model;
    }

    /**
     * Assigns the function its index (0 for the default function, otherwise the
     * next free index), visits its time slices and, when the inherited
     * {@code sliceIndex} is below 16, records an end-of-function timing entry.
     */
    @Override
    public void visit(Function f) {
        if (DEFAULT_FUNCTION_ID.equals(f.getId())) {
            curFuncIndex = 0;
        } else {
            curFuncIndex = nextFuncIndex++;
        }
        model.functionsMap.put(f, curFuncIndex);
        super.visit(f);
        // NOTE(review): sliceIndex is inherited from ModelBuilderVisitor and is
        // presumably the number of slices just visited; 16 looks like the
        // per-function slice capacity - confirm.
        if (sliceIndex < 16) {
            EndSliceTime t = new EndSliceTime(curFuncIndex, f, sliceIndex);
            model.timing.put(t.getAddress(), t);
        }
    }

    /**
     * Records the {@link SliceValues} and {@link SliceTime} timing entries of a
     * time slice of the function currently being visited.
     */
    @Override
    public void visit(Timeslice s) {
        super.visit(s);
        // 'current' is maintained by the superclass; presumably the enclosing function here.
        Function enclosing = (Function) current;
        SliceValues v = new SliceValues(curFuncIndex, enclosing, s);
        model.timing.put(v.getAddress(), v);
        SliceTime t = new SliceTime(model, curFuncIndex, enclosing, s);
        model.timing.put(t.getAddress(), t);
    }

    /** Compiles a subroutine and appends it to the model's routines. */
    @Override
    public void visit(Subroutine s) {
        super.visit(s);
        FPGARoutine r = buildRoutine(s);
        model.routines.add(r);
        model.routinesMap.put(s, r);
    }

    /**
     * Compiles a main. Mains are inserted right after the jump table (slot 0),
     * in the order they are found in the file, ahead of all other routines.
     */
    @Override
    public void visit(Main m) {
        super.visit(m);
        FPGARoutine r = buildRoutine(m);
        model.routines.add(1 + nMains++, r);
        model.routinesMap.put(m, r);
    }

    /** Gives the function pointer the next function-pointer index. */
    @Override
    public void visit(FunctionPointer fp) {
        model.functionsPtrMap.put(fp, curFuncPtrIndex++);
        super.visit(fp);
    }

    /** Gives the subroutine pointer the next subroutine-pointer index. */
    @Override
    public void visit(SubroutinePointer sp) {
        model.routinesPtrMap.put(sp, curSubPtrIndex++);
        super.visit(sp);
    }

    /** Gives the repeat-function pointer the next repeat-function-pointer index. */
    @Override
    public void visit(RepeatFunctionPointer rfp) {
        model.functionsRepPtrMap.put(rfp, curRepFuncPtrIndex++);
        super.visit(rfp);
    }

    /** Gives the repeat-subroutine pointer the next repeat-subroutine-pointer index. */
    @Override
    public void visit(RepeatSubroutinePointer rsp) {
        model.routinesRepPtrMap.put(rsp, curRepSubPtrIndex++);
        super.visit(rsp);
    }

    /**
     * Compiles a subroutine (or embedded subroutine) into an FPGA routine: one
     * stack opcode per call, terminated by a return opcode. Embedded calls are
     * compiled recursively into separate routines appended to the model.
     *
     * @throws RuntimeException if {@code c} is neither a Subroutine nor an EmbeddedSubroutine
     */
    private FPGARoutine buildRoutine(Callable c) {
        FPGARoutine r = new FPGARoutine();
        r.orgRoutine = c;
        List<Call> calls;
        if (c instanceof Subroutine) {
            calls = ((Subroutine) c).getCalls();
        } else if (c instanceof EmbeddedSubroutine) {
            calls = ((EmbeddedSubroutine) c).getCalls();
        } else {
            throw new RuntimeException("type mismatch for " + c + " should be Subroutine or Embedded");
        }
        for (Call call : calls) {
            r.add(buildOpCode(c, r, call));
        }
        r.add(new StackReturnOpcode(r));
        return r;
    }

    /**
     * Builds the stack opcode for one call made by routine {@code r} (compiled
     * from {@code c}), choosing the opcode kind from what the call targets and
     * whether its repeat count is given by a pointer.
     */
    private StackOpCode buildOpCode(Callable c, FPGARoutine r, Call call) {
        if (call.getFunction() != null) {
            if (call.getRepeatFcnPtr() == null) {
                return model.new StackFunctionOpCode(call, r);
            }
            return model.new StackFunctionRepPtrOpCode(call, r);
        }
        if (call.getFunctionPointer() != null) {
            if (call.getRepeatFcnPtr() == null) {
                return model.new StackFunctionPtrOpCode(call, r);
            }
            return model.new StackFunctionPtrRepPtrOpCode(call, r);
        }
        if (call.getSubroutine() != null) {
            // The callee may not be compiled yet: it is left null here and
            // resolved in resolveSubroutineCallees().
            if (call.getRepeatSubPtr() == null) {
                return new StackSubroutineOpCode(call, r, null);
            }
            return model.new StackSubroutineRepPtrOpCode(call, r, null);
        }
        if (call.getSubroutinePointer() != null) {
            if (call.getRepeatSubPtr() == null) {
                return model.new StackSubroutinePtrOpCode(call, r);
            }
            return model.new StackSubroutinePtrRepPtrOpCode(call, r);
        }
        // Otherwise the call carries its own nested calls: compile them into an
        // anonymous embedded subroutine attached to the outermost named subroutine.
        Subroutine parent;
        if (c instanceof Subroutine) {
            parent = (Subroutine) c;
        } else {
            parent = ((EmbeddedSubroutine) c).getParent();
        }
        EmbeddedSubroutine es = new EmbeddedSubroutine(parent);
        es.setCalls(call.getCalls());
        FPGARoutine er = buildRoutine(es);
        model.routines.add(er);
        if (call.getRepeatSubPtr() == null) {
            return new StackSubroutineOpCode(call, r, er);
        }
        return model.new StackSubroutineRepPtrOpCode(call, r, er);
    }

    /**
     * Finishes routine generation: resolves subroutine callees, assigns base
     * addresses and fills in the jump table.
     */
    private void completeRoutines() {
        resolveSubroutineCallees();
        assignBaseAddresses();
        fillJumpTable();
    }

    /** Fills in the callee routine of every subroutine-call opcode that was left unresolved. */
    private void resolveSubroutineCallees() {
        for (FPGARoutine r : model.routines) {
            for (StackOpCode op : r.opcodes) {
                if (op instanceof StackSubroutineOpCode) {
                    StackSubroutineOpCode sop = (StackSubroutineOpCode) op;
                    if (sop.callee == null) {
                        sop.callee = model.routinesMap.get(sop.call.getSubroutine());
                    }
                } else if (op instanceof StackSubroutineRepPtrOpCode) {
                    StackSubroutineRepPtrOpCode sop = (StackSubroutineRepPtrOpCode) op;
                    if (sop.callee == null) {
                        sop.callee = model.routinesMap.get(sop.call.getSubroutine());
                    }
                }
            }
        }
    }

    /**
     * Lays out every routine except the jump table after the space reserved for
     * the jump table, and records the last used address in the model.
     */
    private void assignBaseAddresses() {
        // reserve space for the jump table: MAX_ENTRIES entries of ENTRY_SIZE words each
        int addr = JumpTable.MAX_ENTRIES * JumpTable.ENTRY_SIZE;
        for (FPGARoutine r : model.routines.subList(1, model.routines.size())) {
            r.baseAddress = addr;
            addr += r.opcodes.size();
            // round up : THIS IS AN LPNHE REQUEST. They like 0-ending hex addresses.
            // please keep this.
            // Note: this always advances to the *next* multiple of 16, so a routine
            // ending exactly on a boundary is still followed by 16 padding words.
            addr = ((addr + 16) / 16) * 16;
        }
        model.lastAddr = addr - 1;
    }

    /**
     * Adds one jump-table entry per main, at most {@code JumpTable.MAX_ENTRIES}
     * of them so the table never overflows its reserved space.
     */
    private void fillJumpTable() {
        FPGARoutine jumpTable = model.routines.get(0);
        Map<Subroutine, FPGARoutine> routinesMap = model.routinesMap;
        int nEntries = 0;
        // NOTE(review): entry order follows routinesMap iteration order; confirm it
        // is insertion-ordered if the position of mains in the table matters.
        for (Map.Entry<Subroutine, FPGARoutine> entry : routinesMap.entrySet()) {
            if (!(entry.getKey() instanceof Main)) {
                continue;
            }
            if (nEntries >= JumpTable.MAX_ENTRIES) {
                LOG.log(Level.WARNING,
                        "More than {0} main routines; extra mains were not added to the jump table",
                        JumpTable.MAX_ENTRIES);
                break;
            }
            Call call = new Call(entry.getKey());
            call.setRepeatValue(1);
            // Per instructions, each entry in the jump table uses ENTRY_SIZE words
            // (although it is not clear why): the call, padded with returns.
            jumpTable.add(new StackSubroutineOpCode(call, jumpTable, entry.getValue()));
            for (int j = 1; j < JumpTable.ENTRY_SIZE; j++) {
                jumpTable.add(new StackReturnOpcode(jumpTable));
            }
            nEntries++;
        }
    }

    /**
     * Computes the start time and clock-quantized actual duration of every time
     * slice of every function.
     * <p>
     * Per the original authors' note, the first slice lasts one clock period and
     * the last slice two clock periods longer than the value written to their
     * registers, hence the minimum lengths enforced here. All functions except
     * the default function must have at least two slices.
     *
     * @throws RuntimeException if a non-default function violates these constraints
     */
    private void completeFunctions() {
        BigDecimal clockPeriod = model.getClockPeriod();
        BigDecimal twoClockPeriods = clockPeriod.multiply(BigDecimal.valueOf(2));
        for (Function f : model.functionsMap.keySet()) {
            boolean isDefault = DEFAULT_FUNCTION_ID.equals(f.getId());
            List<Timeslice> slices = f.getTimeslices();
            // stays null for the default function, whose slices are never treated as "last"
            Timeslice lastSlice = null;
            if (!isDefault) {
                if (slices.size() < 2) {
                    throw new RuntimeException("All functions except the default function must have at least two time slices");
                }
                Timeslice firstSlice = slices.get(0);
                if (firstSlice.getDurationNanos().compareTo(clockPeriod) < 0) {
                    throw new RuntimeException("First slice in function must be >= clockperiod");
                }
                lastSlice = slices.get(slices.size() - 1);
                if (lastSlice.getDurationNanos().compareTo(twoClockPeriods) < 0) {
                    throw new RuntimeException("Last timeslice of a function must be >= 2 * clockperiod");
                }
            }
            BigDecimal nanos = BigDecimal.ZERO;
            for (Timeslice slice : slices) {
                slice.setStartNanos(nanos);
                int clockCycles = slice.computeClockCycles(clockPeriod, slice == lastSlice, isDefault);
                BigDecimal actualDuration = clockPeriod.multiply(BigDecimal.valueOf(clockCycles));
                slice.setActualDurationNanos(actualDuration);
                nanos = nanos.add(actualDuration);
            }
        }
    }

    /**
     * Visits the whole sequencer, then finalizes the model: metadata and clock
     * period first (timing depends on it), then function timing, routine layout
     * and channels.
     */
    @Override
    public void visit(Sequencer s) {
        super.visit(s);
        updateMetadata(s);
        completeFunctions();
        completeRoutines();
        updateChannels(s);
    }

    /**
     * Copies the sequencer parameters into the model metadata and sets the model
     * clock period: the "clockperiod" parameter if present (10 otherwise),
     * replaced by the override when one was given and differs numerically.
     */
    private void updateMetadata(Sequencer s) {
        for (Parameter p : s.getSequencerConfig().getParameters()) {
            model.addMetadata(p.getId(), p.getValue());
        }
        String d = model.getMetadata().get("clockperiod");
        BigDecimal cp = (d == null) ? BigDecimal.TEN : parseNanos(d);
        // compareTo, not equals: BigDecimal.equals also compares scale (10 vs 10.0)
        if (clockPeriodOverride != null && cp.compareTo(clockPeriodOverride) != 0) {
            LOG.log(Level.WARNING, "Clock period specified in sequencer file {0} overriden to {1}", new Object[]{cp, clockPeriodOverride});
            cp = clockPeriodOverride;
            model.addMetadata("clockperiod", cp.toString() + " ns");
        }
        model.setClockPeriod(cp);
    }

    /** Copies the sequencer channel definitions into the model. */
    private void updateChannels(Sequencer s) {
        for (Channel c : s.getSequencerConfig().getChannels()) {
            model.addChannel(c.getId(), c.getValue());
        }
    }

}
