package org.lsst.ccs.drivers.auxelex;

import java.io.IOException;
import java.util.Collection;
import java.net.DatagramPacket;
import java.net.DatagramSocket;
import java.net.InetAddress;
import java.net.SocketTimeoutException;
import java.util.HashMap;
import java.util.Map;
import org.lsst.ccs.drivers.commons.DriverException;
import org.lsst.ccs.drivers.commons.DriverTimeoutException;
import org.lsst.ccs.utilities.conv.Convert;

/**
 *  Routines for implementing the SLAC register protocol
 *
 *  An instance talks to a single board over a connected UDP socket, using
 *  either the version 3 protocol (the default) or the older protocol, or it
 *  runs in simulation mode against an in-memory register map.
 *
 *  Register transactions (readRegs, writeRegs, updateReg), open and
 *  setSrpVersion are synchronized on the instance.  close is deliberately
 *  not synchronized so that it can interrupt a read that is blocked in
 *  another thread.
 *
 *  @author  Owen Saxton
 */
public class Srp {

    /**
     *  Constants and data.
     */
    public static final int
        DEFAULT_PORT = 8192;
    public enum BoardType {UNKNOWN, REB_PS_PROTO, REB_PS_UPDATE, REB_PS_PROD, PDU_5V,
                           PDU_24V_DIRTY, PDU_24V_CLEAN, PDU_48V, BFR, ION_PUMP, HEATER}

    private static final int
        MAX_REGS = 64,
        MAX_REGS_3 = 1024,

        READ_TIMEOUT = 1000,
        MAX_TRIES = 2,

        SRP_PKT_LENG    = 20,
        SRP_OFF_HEADER  = 4,
        SRP_OFF_OC      = 8,
        SRP_OFF_ADDR    = 8,
        SRP_OFF_DATA    = 12,
        SRP_OFF_FOOTER  = 16,
        SRP_STS_TIMEOUT = 0x02,
        SRP_STS_ERROR   = 0x01,

        SRP3_PKT_LENG    = 20,
        SRP3_OFF_HEADER  = 0,
        SRP3_OFF_VERSION = 0,
        SRP3_OFF_OPCODE  = 1,
        SRP3_OFF_TIMEOUT = 3,
        SRP3_OFF_TID     = 4,
        SRP3_OFF_ADDR_LO = 8,
        SRP3_OFF_ADDR_HI = 12,
        SRP3_OFF_SIZE    = 16,
        SRP3_OFF_DATA    = 20,
        SRP3_OFF_FOOTER  = 20,
        SRP3_OPC_NP_READ  = 0,
        SRP3_OPC_NP_WRITE = 1,
        SRP3_OPC_P_WRITE  = 2,
        SRP3_STS_MEMORY   = 0x00ff,
        SRP3_STS_TIMEOUT  = 0x0100,
        SRP3_STS_EOFE     = 0x0200,
        SRP3_STS_FRAMING  = 0x0400,
        SRP3_STS_VERSION  = 0x0800,
        SRP3_STS_REQUEST  = 0x1000,

        ADDR_REB_PS_PROTO_0 = 20,
        ADDR_REB_PS_PROTO_1 = 21,
        ADDR_REB_PS_UPDATE = 22,
        ADDR_REB_PS_PROD_MIN = 39,
        ADDR_REB_PS_PROD_MAX = 62,
        ADDR_REB_PS_PROD_EXTRA = 76,
        ADDR_PDU_5V = 68,
        ADDR_PDU_24V_DIRTY = 66,
        ADDR_PDU_24V_CLEAN = 67,
        ADDR_PDU_48V = 65,
        ADDR_BFR = 63,
        // NOTE(review): ION_PUMP and HEATER share address 0, so an address of
        // 0 is always reported as ION_PUMP - confirm the intended values.
        ADDR_ION_PUMP = 0,
        ADDR_HEATER = 0;

    private static final byte
        SRP_OC_READ  = 0x00,
        SRP_OC_WRITE = 0x40;

    private BoardType boardType = BoardType.UNKNOWN;
    // Volatile because close() clears it without holding the instance lock
    private volatile DatagramSocket sock;
    private int seqno, nSeqErr, nTimeout;
    private int srpVersion = 3, maxRegs = MAX_REGS_3;
    private byte[]
        inBuff = new byte[bufferSize(srpVersion)],
        outBuff = new byte[bufferSize(srpVersion)];
    private DatagramPacket
        inPkt = new DatagramPacket(inBuff, inBuff.length),
        outPkt = new DatagramPacket(outBuff, outBuff.length);
    private byte[] ipAddr = {0, 0, 0, 0};
    private boolean simulated, debug;
    private final Map<Integer, Integer> simRegMap = new HashMap<>();
    private Collection<BoardType> validBoardTypes;
    private int probeAddress;


    /**
     *  Gets the packet buffer size needed for a protocol version.
     *
     *  A version 3 packet can carry the header, the maximum number of data
     *  words and (in a response) a trailing status word.  An old protocol
     *  packet of SRP_PKT_LENG bytes already includes one data word and the
     *  footer, so only the extra data words are added.
     *
     *  @param  version  The protocol version
     *  @return  The buffer size in bytes
     */
    private static int bufferSize(int version)
    {
        return version == 3 ? SRP3_PKT_LENG + 4 * MAX_REGS_3 + 4
                            : SRP_PKT_LENG + 4 * (MAX_REGS - 1);
    }


    /**
     *  Sets the collection of valid board types.
     *
     *  The collection is checked when a (non-simulated) connection is opened.
     *
     *  @param  types  The collection of valid types, or null to accept any type
     */
    public void setValidBoardTypes(Collection<BoardType> types)
    {
        validBoardTypes = types;
    }


    /**
     *  Sets the address to be probed at open time.
     *
     *  The value is only stored: the probe read in open is currently disabled.
     *
     *  @param  addr  The address to be probed, or -1 to prevent probing
     */
    public void setProbeAddress(int addr)
    {
        probeAddress = addr;
    }


    /**
     *  Sets the SRP protocol version.
     *
     *  Version 3 selects the new protocol; any other value selects the old
     *  one.  Changing the version replaces the packet buffers, so the method
     *  is synchronized to keep it from running during a register transaction.
     *
     *  @param  version  The protocol version
     */
    public synchronized void setSrpVersion(int version)
    {
        if (version == srpVersion) return;
        srpVersion = version;
        maxRegs = srpVersion == 3 ? MAX_REGS_3 : MAX_REGS;
        int size = bufferSize(srpVersion);
        inBuff = new byte[size];
        outBuff = new byte[size];
        inPkt = new DatagramPacket(inBuff, inBuff.length);
        DatagramPacket newOutPkt = new DatagramPacket(outBuff, outBuff.length);
        // Carry over the destination set by open(), if any
        if (outPkt.getAddress() != null) {
            newOutPkt.setAddress(outPkt.getAddress());
            newOutPkt.setPort(outPkt.getPort());
        }
        outPkt = newOutPkt;
    }


    /**
     *  Opens a connection to a board.
     *
     *  The error counters are reset.  For a real connection the socket is
     *  connected to the board, the read timeout is set and the board type is
     *  derived from the IP address.  If anything fails, the newly created
     *  socket is closed before the exception is thrown.
     *
     *  @param  host  The host name or IP address, or null or empty for simulation
     *  @param  port  The port number, or 0 to use DEFAULT_PORT
     *  @throws  DriverException  if the connection is already open, the host
     *                            cannot be reached or the board type is not valid
     */
    public synchronized void open(String host, int port) throws DriverException
    {
        if (sock != null) {
            throw new DriverException("Connection is already open");
        }
        nSeqErr = 0;
        nTimeout = 0;
        DatagramSocket newSock = null;
        boolean opened = false;
        try {
            newSock = new DatagramSocket();
            if (host == null || host.isEmpty()) {
                simulated = true;
                simInitialize();
            }
            else {
                simulated = false;
                int actPort = port == 0 ? DEFAULT_PORT : port;
                InetAddress inetAddr = InetAddress.getByName(host);
                newSock.connect(inetAddr, actPort);
                newSock.setSoTimeout(READ_TIMEOUT);
                outPkt.setAddress(inetAddr);
                outPkt.setPort(actPort);
                ipAddr = inetAddr.getAddress();
                setBoardType();
                // NOTE: the probe read of probeAddress (when not -1) is
                // currently disabled.
            }
            sock = newSock;
            opened = true;
        }
        catch (IOException e) {
            throw new DriverException(e);
        }
        finally {
            // Don't leak the socket if the open didn't complete
            if (!opened && newSock != null) {
                newSock.close();
            }
        }
    }


    /**
     *  Opens a connection to a board using the default port.
     *
     *  @param  ipAddr  The IP address, or null or empty for simulation
     *  @throws  DriverException
     */
    public void open(String ipAddr) throws DriverException
    {
        open(ipAddr, 0);
    }


    /**
     *  Closes the connection.
     *
     *  This method isn't synchronized in order to allow a connection to be
     *  closed while another thread is waiting for a read to complete.  The
     *  socket reference is captured in a local variable before the field is
     *  cleared, so a concurrent close cannot cause a null pointer exception
     *  here.
     *
     *  @throws  DriverException  if the connection is not open
     */
    public void close() throws DriverException
    {
        DatagramSocket oldSock = sock;
        if (oldSock == null) {
            throw new DriverException("Connection is not open");
        }
        sock = null;
        oldSock.close();
    }


    /**
     *  Sets the board type.
     *
     *  The type is derived from the fourth byte of the IP address.  It is
     *  reset to UNKNOWN first so that a type from an earlier connection is
     *  not retained when the address isn't recognized.
     *
     *  @throws  DriverException  if a set of valid board types has been
     *                            specified and doesn't contain the type
     */
    private void setBoardType() throws DriverException
    {
        boardType = BoardType.UNKNOWN;
        int addr = ipAddr[3] & 0xff;    // Treat the address byte as unsigned
        if (addr == ADDR_REB_PS_PROTO_0 || addr == ADDR_REB_PS_PROTO_1) {
            boardType = BoardType.REB_PS_PROTO;
        }
        else if (addr == ADDR_REB_PS_UPDATE) {
            boardType = BoardType.REB_PS_UPDATE;
        }
        else if (addr >= ADDR_REB_PS_PROD_MIN && addr <= ADDR_REB_PS_PROD_MAX
                   || addr == ADDR_REB_PS_PROD_EXTRA) {
            boardType = BoardType.REB_PS_PROD;
        }
        else if (addr == ADDR_PDU_5V) {
            boardType = BoardType.PDU_5V;
        }
        else if (addr == ADDR_PDU_24V_DIRTY) {
            boardType = BoardType.PDU_24V_DIRTY;
        }
        else if (addr == ADDR_PDU_24V_CLEAN) {
            boardType = BoardType.PDU_24V_CLEAN;
        }
        else if (addr == ADDR_PDU_48V) {
            boardType = BoardType.PDU_48V;
        }
        else if (addr == ADDR_BFR) {
            boardType = BoardType.BFR;
        }
        else if (addr == ADDR_ION_PUMP) {
            boardType = BoardType.ION_PUMP;
        }
        else if (addr == ADDR_HEATER) {
            // Never reached while ADDR_HEATER equals ADDR_ION_PUMP
            boardType = BoardType.HEATER;
        }
        if (validBoardTypes != null && !validBoardTypes.contains(boardType)) {
            throw new DriverException("Invalid board type: " + boardType);
        }
    }


    /**
     *  Gets the board type.
     *
     *  @return  The board type
     */
    public BoardType getBoardType()
    {
        return boardType;
    }


    /**
     *  Sets the debug state.
     *
     *  When on, every packet sent or received is dumped to standard output.
     *
     *  @param  on  The debug on state, true or false; null is treated as false
     */
    public void setDebug(Boolean on)
    {
        debug = on != null && on;
    }


    /**
     *  Gets the IP address.
     *
     *  @return  A copy of the IP address byte array
     */
    public byte[] getIpAddress()
    {
        return ipAddr.clone();
    }


    /**
     *  Gets whether simulated
     *
     *  @return  Whether simulated
     */
    public boolean isSimulated()
    {
        return simulated;
    }


    /**
     *  Writes a register.
     *
     *  @param  addr   The register address
     *  @param  value  The value to write
     *  @throws  DriverException
     */
    public void writeReg(int addr, int value) throws DriverException
    {
        writeRegs(addr, new int[]{value});
    }


    /**
     *  Writes registers.
     *
     *  @param  addr   The first register address
     *  @param  value  The array of values to write.  If it contains more than the allowed
     *                 maximum number (1024 for SRP V3, 64 otherwise) of values, only the
     *                 allowed number are written.  If it is empty, nothing is written.
     *  @throws  DriverException
     */
    public synchronized void writeRegs(int addr, int[] value) throws DriverException
    {
        checkOpen();
        int count = Math.min(value.length, maxRegs);
        if (count == 0) {
            return;    // Nothing to write; a zero-length request would be malformed
        }
        if (simulated) {
            simWriteRegs(addr, value, count);
        }
        else if (srpVersion == 3) {
            writeRegs3(addr, value, count);
        }
        else {
            writeRegs0(addr, value, count);
        }
    }


    /**
     *  Writes registers using new (version 3) SRP protocol.
     *
     *  @param  addr   The first register address
     *  @param  value  The array of values to write
     *  @param  count  The number of values to write (at least 1)
     *  @throws  DriverException
     */
    private void writeRegs3(int addr, int[] value, int count) throws DriverException
    {
        outBuff[SRP3_OFF_VERSION] = 3;
        outBuff[SRP3_OFF_OPCODE] = SRP3_OPC_NP_WRITE;
        outBuff[SRP3_OFF_TIMEOUT] = 10;
        // The address field is 4 * the register address; the size field is
        // the byte count minus 1
        Convert.intToBytes(4 * addr, outBuff, SRP3_OFF_ADDR_LO);
        Convert.intToBytes(0, outBuff, SRP3_OFF_ADDR_HI);
        Convert.intToBytes(4 * count - 1, outBuff, SRP3_OFF_SIZE);
        for (int j = 0; j < count; j++) {
            Convert.intToBytes(value[j], outBuff, SRP3_OFF_DATA + 4 * j);
        }
        transact(SRP3_PKT_LENG + 4 * count);
    }


    /**
     *  Writes registers using old SRP protocol.
     *
     *  @param  addr   The first register address
     *  @param  value  The array of values to write.
     *  @param  count  The number of registers to write (at least 1)
     *  @throws  DriverException
     */
    private void writeRegs0(int addr, int[] value, int count) throws DriverException
    {
        Convert.intToBytesBE(addr, outBuff, SRP_OFF_ADDR);
        for (int j = 0; j < count; j++) {
            Convert.intToBytesBE(value[j], outBuff, SRP_OFF_DATA + 4 * j);
        }
        // Set after the address since both occupy the same offset
        outBuff[SRP_OFF_OC] = SRP_OC_WRITE;
        transact(SRP_PKT_LENG + 4 * (count - 1));
    }


    /**
     *  Reads a register.
     *
     *  @param  addr  The register address
     *  @return  The read value
     *  @throws  DriverException
     */
    public int readReg(int addr) throws DriverException
    {
        return readRegs(addr, 1)[0];
    }


    /**
     *  Reads registers.
     *
     *  @param  addr   The first register address
     *  @param  count  The number of registers to read.  If greater than the allowed
     *                 maximum number (1024 for SRP V3, 64 otherwise) of registers, only the
     *                 allowed number are read.  If zero or negative, an empty array
     *                 is returned.
     *  @return  The array of read values
     *  @throws  DriverException
     */
    public synchronized int[] readRegs(int addr, int count) throws DriverException
    {
        checkOpen();
        if (count <= 0) {
            return new int[0];
        }
        count = Math.min(count, maxRegs);
        if (simulated) {
            return simReadRegs(addr, count);
        }
        else if (srpVersion == 3) {
            return readRegs3(addr, count);
        }
        else {
            return readRegs0(addr, count);
        }
    }


    /**
     *  Reads registers using new (version 3) SRP protocol.
     *
     *  The number of values returned is derived from the length of the
     *  response packet.
     *
     *  @param  addr   The first register address
     *  @param  count  The number of registers to read (at least 1, already
     *                 limited by the caller to the allowed maximum)
     *  @return  The array of read values
     *  @throws  DriverException
     */
    private int[] readRegs3(int addr, int count) throws DriverException
    {
        outBuff[SRP3_OFF_VERSION] = 3;
        outBuff[SRP3_OFF_OPCODE] = SRP3_OPC_NP_READ;
        outBuff[SRP3_OFF_TIMEOUT] = 10;
        Convert.intToBytes(4 * addr, outBuff, SRP3_OFF_ADDR_LO);
        Convert.intToBytes(0, outBuff, SRP3_OFF_ADDR_HI);
        Convert.intToBytes(4 * count - 1, outBuff, SRP3_OFF_SIZE);
        transact(SRP3_PKT_LENG);
        // The response is the header, the data words and a status word
        int nRead = (inPkt.getLength() - SRP3_PKT_LENG) / 4 - 1;
        if (nRead < 1) {
            throw new DriverException("No register data in response");
        }
        int[] data = new int[nRead];
        for (int j = 0; j < nRead; j++) {
            data[j] = Convert.bytesToInt(inBuff, SRP3_OFF_DATA + 4 * j);
        }
        return data;
    }


    /**
     *  Reads registers using old SRP protocol.
     *
     *  The number of values returned is derived from the length of the
     *  response packet.
     *
     *  @param  addr   The first register address
     *  @param  count  The number of registers to read (at least 1, already
     *                 limited by the caller to the allowed maximum)
     *  @return  The array of read values
     *  @throws  DriverException
     */
    private int[] readRegs0(int addr, int count) throws DriverException
    {
        Convert.intToBytesBE(addr, outBuff, SRP_OFF_ADDR);
        Convert.intToBytesBE(count - 1, outBuff, SRP_OFF_DATA);
        // Set after the address since both occupy the same offset
        outBuff[SRP_OFF_OC] = SRP_OC_READ;
        transact(SRP_PKT_LENG);
        // A packet of SRP_PKT_LENG bytes carries one data word
        int nRead = (inPkt.getLength() - SRP_PKT_LENG) / 4 + 1;
        if (nRead < 1) {
            throw new DriverException("No register data in response");
        }
        int[] data = new int[nRead];
        for (int j = 0; j < nRead; j++) {
            data[j] = Convert.bytesToIntBE(inBuff, SRP_OFF_DATA + 4 * j);
        }
        return data;
    }


    /**
     *  Updates a register.
     *
     *  The register is read, the masked bits are replaced and the result is
     *  written back, all while holding the instance lock.
     *
     *  @param  addr   The register address
     *  @param  mask   The mask of bits to be updated
     *  @param  value  The value to use for updating
     *  @return  Previous register value
     *  @throws  DriverException
     */
    public synchronized int updateReg(int addr, int mask, int value) throws DriverException
    {
        int prevValue = readReg(addr);
        writeReg(addr, (prevValue & ~mask) | (value & mask));
        return prevValue;
    }


    /**
     *  Checks whether a connection is open.
     *
     *  @throws  DriverException  if the connection is not open
     */
    private void checkOpen() throws DriverException
    {
        if (sock == null) {
            throw new DriverException("Connection is not open");
        }
    }


    /**
     *  Sends the "out" packet and receives the response.
     *
     *  If the response times out, the request is sent again (with a new
     *  sequence number), up to MAX_TRIES attempts in all.
     *
     *  @param  leng  The length of the request to send
     *  @throws  DriverException
     */
    private void transact(int leng) throws DriverException
    {
        for (int attempt = 1; ; attempt++) {
            try {
                send(leng);
                receive();
                return;
            }
            catch (DriverTimeoutException e) {
                if (attempt >= MAX_TRIES) {
                    throw e;
                }
            }
        }
    }


    /**
     *  Sends the "out" packet.
     *
     *  The sequence number is advanced and stored in the packet, and for the
     *  old protocol the footer word is zeroed.
     *
     *  @param  leng  The number of bytes of the out buffer to send
     *  @throws  DriverException
     */
    private void send(int leng) throws DriverException
    {
        // Use a local copy since close() may clear the field at any time
        DatagramSocket sendSock = sock;
        if (sendSock == null) {
            throw new DriverException("Connection is not open");
        }
        seqno += 4;
        Convert.intToBytes(seqno, outBuff, srpVersion == 3 ? SRP3_OFF_TID : SRP_OFF_HEADER);
        if (srpVersion != 3) {
            Convert.intToBytes(0, outBuff, leng - SRP_PKT_LENG + SRP_OFF_FOOTER);
        }
        showData("Sent:", leng, outBuff);
        outPkt.setLength(leng);
        try {
            sendSock.send(outPkt);
        }
        catch (IOException e) {
            throw new DriverException(e);
        }
    }


    /**
     *  Receives the "in" packet.
     *
     *  Packets whose sequence number doesn't match the last one sent are
     *  counted and discarded.  The status word at the end of the matching
     *  packet is then checked.
     *
     *  @throws  DriverTimeoutException  if no packet arrives within the read timeout
     *  @throws  DriverException  if the receive fails, the packet is too short
     *                            or the status word reports an error
     */
    private void receive() throws DriverException
    {
        // Use a local copy since close() may clear the field at any time
        DatagramSocket rcvSock = sock;
        if (rcvSock == null) {
            throw new DriverException("Connection is not open");
        }
        boolean isV3 = srpVersion == 3;
        int seqOffset = isV3 ? SRP3_OFF_TID : SRP_OFF_HEADER;
        while (true) {
            inPkt.setLength(inBuff.length);
            try {
                rcvSock.receive(inPkt);
                // A packet too short to hold a sequence number can't match
                if (inPkt.getLength() >= seqOffset + 4
                      && Convert.bytesToInt(inBuff, seqOffset) == seqno) break;
                nSeqErr++;
            }
            catch (SocketTimeoutException e) {
                nTimeout++;
                throw new DriverTimeoutException();
            }
            catch (IOException e) {
                throw new DriverException(e);
            }
        }
        int leng = inPkt.getLength();
        showData("Rcvd:", leng, inBuff);
        // The packet must hold at least the header fields and the status word
        int minLeng = isV3 ? SRP3_PKT_LENG + 4 : SRP_PKT_LENG - 4;
        if (leng < minLeng) {
            throw new DriverException("Response packet too short (" + leng + " bytes)");
        }
        int status;
        String message = null;
        // When several status bits are set, the last one tested determines the message
        if (isV3) {
            status = Convert.bytesToInt(inBuff, leng - 4);
            if ((status & SRP3_STS_MEMORY) != 0) {
                message = "Invalid register address";
            }
            if ((status & SRP3_STS_TIMEOUT) != 0) {
                message = "Register access timeout";
            }
            if ((status & SRP3_STS_EOFE) != 0) {
                message = "End of frame with error";
            }
            if ((status & SRP3_STS_FRAMING) != 0) {
                message = "Framing error";
            }
            if ((status & SRP3_STS_VERSION) != 0) {
                message = "Version mismatch";
            }
            if ((status & SRP3_STS_REQUEST) != 0) {
                message = "Invalid request";
            }
        }
        else {
            status = Convert.bytesToIntBE(inBuff, leng + SRP_OFF_FOOTER - SRP_PKT_LENG);
            if ((status & SRP_STS_TIMEOUT) != 0) {
                message = "Register access timeout";
            }
            if ((status & SRP_STS_ERROR) != 0) {
                message = "Register access error";
            }
        }
        if (message != null) {
            throw new DriverException(message);
        }
    }


    /**
     *  Display data buffer
     *
     *  Does nothing unless debug is on.  The bytes are written to standard
     *  output in hexadecimal, in groups of four, 32 to a line, with
     *  continuation lines indented to line up after the title.
     *
     *  @param  title  The title to display at the start of the first line
     *  @param  leng   The number of bytes to display
     *  @param  data   The buffer containing the bytes
     */
    private void showData(String title, int leng, byte[] data)
    {
        if (!debug) return;
        String blanks = null;
        for (int j = 0; j < leng; j++) {
            if ((j % 32) == 0) {
                if (j == 0) {
                    System.out.print(title);
                }
                else {
                    if (blanks == null) {
                        // A newline followed by as many blanks as the title is long
                        blanks = "\n                    ".substring(0, title.length() + 1);
                    }
                    System.out.print(blanks);
                }
            }
            if ((j % 4) == 0) {
                System.out.print(" ");
            }
            System.out.format("%02x", data[j] & 0xff);
        }
        System.out.println();
    }


    /**
     *  Gets the sequence error count.
     *
     *  The count is reset each time a connection is opened.
     *
     *  @return  The number of sequence errors
     */
    public int getNumSeqErr()
    {
        return nSeqErr;
    }


    /**
     *  Gets the socket timeout count.
     *
     *  The count is reset each time a connection is opened.
     *
     *  @return  The number of socket timeouts
     */
    public int getNumTimeout()
    {
        return nTimeout;
    }


    /**
     *  Initializes the simulation.
     *
     *  Called when a simulated connection is opened.  This implementation
     *  clears the simulated register map.
     */
    protected void simInitialize()
    {
        clearSimRegMap();
    }


    /**
     *  Writes simulated registers.
     *
     *  @param  addr   The first register address
     *  @param  value  The array of values to write.
     *  @param  count  The number of values to write.
     */
    protected void simWriteRegs(int addr, int[] value, int count)
    {
        for (int j = 0; j < count; j++, addr++) {
            putSimRegMap(addr, value[j]);
        }
    }


    /**
     *  Reads simulated registers.
     *
     *  @param  addr   The first register address
     *  @param  count  The number of registers to read.
     *  @return  The array of read values
     */
    protected int[] simReadRegs(int addr, int count)
    {
        int[] data = new int[count];
        for (int j = 0; j < count; j++) {
            data[j] = getSimRegMap(addr + j);
        }
        return data;
    }


    /**
     *  Clears the simulated register map.
     */
    protected final void clearSimRegMap()
    {
        simRegMap.clear();
    }


    /**
     *  Puts to the simulated register map.
     *
     *  @param  addr   The register address
     *  @param  value  The value to store
     */
    protected final void putSimRegMap(int addr, int value)
    {
        simRegMap.put(addr, value);
    }


    /**
     *  Gets from the simulated register map.
     *
     *  @param  addr   The register address
     *  @return  The stored value, or 0 if the register has not been written
     */
    protected final int getSimRegMap(int addr)
    {
        Integer value = simRegMap.get(addr);
        return value == null ? 0 : value;
    }

}
