package org.lsst.ccs.subsystems.console.jython;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Properties;
import java.util.concurrent.CountDownLatch;

import org.apache.commons.cli.BasicParser;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.commons.net.util.SubnetUtils;
import org.lsst.ccs.BusMaster;
import org.lsst.ccs.bus.data.AgentInfo;
import org.lsst.ccs.command.annotations.Command;
import org.lsst.ccs.framework.HasLifecycle;
import org.lsst.ccs.scripting.CCS;
import org.lsst.ccs.scripting.jython.JythonScriptExecutorUtils;
import org.lsst.ccs.startup.HasCommandLineOptions;
import org.lsst.ccs.utilities.logging.Logger;
import org.python.core.PyObject;
import org.python.core.PySystemState;
import org.python.util.PythonInterpreter;

/**
 * CCS {@code BusMaster} service agent that serves Jython interpreters over TCP sockets.
 * <p>
 * {@link #start()} binds a {@link ServerSocket} on the configured port (4444 for the no-arg
 * constructor, overridable with the {@code -p/--port} command line option). A daemon thread
 * then accepts connections from localhost or from addresses allowed by
 * {@code JythonConsoleNetworkUtilities.loadNetworkAccessInformation()}. At most
 * {@code org.lsst.ccs.subsystem.console.jython.interpreter.max.sockets} (default 20)
 * connections may be open at once. Each accepted connection gets its own
 * {@link PythonInterpreter}, served by an {@link InternalJythonInterpreterThread}.
 * <p>
 * Line-based socket protocol:
 * <ul>
 * <li>On connect, the server answers {@code ConnectionOK <id>} or {@code ConnectionRefused}.</li>
 * <li>The client sends {@code startContent:<contentId>}, then the content lines, then
 * {@code endContent:}.</li>
 * <li>Content starting with {@code abortInterpreter [name]} or
 * {@code initializeInterpreter <name>} is handled as a console command. Anything else is
 * compiled and executed as Jython on a separate thread.</li>
 * <li>Completion is signalled with {@code doneExecution:<contentId>}.</li>
 * </ul>
 *
 * @author LSST CCS Team
 */
public class JythonInterpreterConsole extends BusMaster implements HasCommandLineOptions, HasLifecycle {

    private static final String PORT_OPTION = "port";
    private static final String START_CONTENT = "startContent:";
    private static final String END_CONTENT = "endContent:";
    private static final String DONE_EXECUTION = "doneExecution:";
    private static final String ABORT_INTERPRETER = "abortInterpreter";
    private static final String INITIALIZE_INTERPRETER = "initializeInterpreter";

    // Open connections keyed by the id of the thread serving them. Every access must hold the map's lock.
    private final HashMap<Long, InternalJythonInterpreterThread> openSocketConnections = new HashMap<>();
    private final int maxNumberSocketConnections = Integer.parseInt(System.getProperty("org.lsst.ccs.subsystem.console.jython.interpreter.max.sockets", "20"));
    private final Logger logger = Logger.getLogger("org.lsst.ccs.console.jython");
    // Written by postShutdown(); read by the listening thread and the interpreter threads.
    private volatile boolean shuttingDown = false;
    private Thread socketThread;

    private ServerSocket serverSocket;
    private int portNumber = -1;
    private final List<SubnetUtils> allowedIps;

    private final Options commandLineOptions;

    public JythonInterpreterConsole() throws IOException {
        this(4444);
    }

    JythonInterpreterConsole(int portNumber) throws IOException {
        this("jython-server", portNumber);
    }

    JythonInterpreterConsole(String consoleName, int portNumber) {
        super(consoleName, AgentInfo.AgentType.SERVICE);

        CCS.setShareLocksAcrossThreads(false);

        commandLineOptions = new Options();
        commandLineOptions.addOption("p", PORT_OPTION, true, "Port to run the interpreter on.");

        Properties props = System.getProperties();
        PythonInterpreter.initialize(props, props, new String[] {""});

        if (this.portNumber < 0) {
            this.portNumber = portNumber;
        }
        allowedIps = JythonConsoleNetworkUtilities.loadNetworkAccessInformation();
    }

    /**
     * Binds the server socket and starts accepting client connections.
     * The JVM exits with status -1 if the port cannot be bound.
     */
    @Override
    public void start() {
        // start() cannot declare a checked exception, so a bind failure is treated as fatal.
        try {
            serverSocket = new ServerSocket(this.portNumber);
        } catch (IOException e) {
            e.printStackTrace();
            System.exit(-1);
        }
        logger.info("Starting JythonInterpreter console " + getName() + " on port " + portNumber);
        startListeningOnSocket();
    }

    /**
     * Interrupts every open interpreter connection, unless one of them is still executing.
     * New connections and new executions are refused while this runs.
     *
     * @throws RuntimeException if at least one interpreter is executing content; it must be
     *         aborted first. NOTE(review): in that case {@code shuttingDown} stays {@code true},
     *         so the console keeps refusing connections and executions. Confirm this is intended.
     */
    @Override
    public void postShutdown() {
        shuttingDown = true;
        boolean canShutdown = true;
        synchronized (openSocketConnections) {
            for (InternalJythonInterpreterThread thread : openSocketConnections.values()) {
                if (thread.getInterpreter().isExecuting()) {
                    canShutdown = false;
                    break;
                }
            }
            if (canShutdown) {
                // Iterate over a snapshot so the map is never traversed while it may change.
                for (InternalJythonInterpreterThread thread : Arrays.asList(openSocketConnections.values().toArray(new InternalJythonInterpreterThread[0]))) {
                    thread.interrupt();
                }
            }
        }

        if (!canShutdown) {
            throw new RuntimeException("Cannot shutdown JythonInterpreterConsole. There is at least an active thread. It needs to be aborted first.");
        }
        shuttingDown = false;
    }

    /*
     * Refuse the socket connection: log the reason, tell the client "ConnectionRefused",
     * then close the writer and the socket.
     */
    private void refuseSocketConnection(PrintWriter out, Socket socket, String message) throws IOException {
        logger.warn(message);
        out.println("ConnectionRefused");
        out.flush();
        out.close();
        socket.close();
    }

    /**
     * Starts the daemon thread that accepts client connections.
     * Returns once that thread is running.
     */
    private void startListeningOnSocket() {
        // The listener is a daemon thread so that it does not keep the console alive on shutdown.
        final CountDownLatch listenerStarted = new CountDownLatch(1);
        socketThread = new Thread() {
            @Override
            public void run() {
                listenerStarted.countDown();
                acceptConnections();
            }
        };
        socketThread.setDaemon(true);
        socketThread.start();
        // Block until the listener runs. This replaces a CPU-burning spin on Thread.getState().
        try {
            listenerStarted.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Accept loop run by the listening thread.
     *
     * @throws RuntimeException wrapping the {@link IOException} if {@code accept()} itself fails
     */
    private void acceptConnections() {
        while (true) {
            Socket clientSocket;
            try {
                clientSocket = serverSocket.accept();
            } catch (IOException ioe) {
                throw new RuntimeException("Problems while establishing a socket connection", ioe);
            }
            try {
                handleNewConnection(clientSocket);
            } catch (IOException ioe) {
                // A failure while setting up one client must not stop the server from serving others.
                logger.warn("Problems while setting up the socket connection from " + clientSocket.getInetAddress() + ": " + ioe);
                closeQuietly(clientSocket);
            }
        }
    }

    /**
     * Accepts or refuses a freshly accepted client socket.
     * <p>
     * Localhost and addresses in the access list are accepted while fewer than
     * {@code maxNumberSocketConnections} connections are open. An accepted client receives
     * {@code ConnectionOK <id>} and is served by a new {@link InternalJythonInterpreterThread}.
     * Any other client receives {@code ConnectionRefused} and its socket is closed.
     *
     * @param clientSocket the socket returned by {@link ServerSocket#accept()}
     * @throws IOException if the socket's streams cannot be obtained or closing it fails
     */
    private void handleNewConnection(Socket clientSocket) throws IOException {
        PrintWriter out = new PrintWriter(clientSocket.getOutputStream(), true);
        InetAddress clientAddress = clientSocket.getInetAddress();
        if (shuttingDown) {
            refuseSocketConnection(out, clientSocket, "The Jython Console is shutting down");
            return;
        }
        synchronized (openSocketConnections) {
            boolean allowed = JythonConsoleNetworkUtilities.isInetAddressLocalhost(clientAddress)
                    || JythonConsoleNetworkUtilities.isIpAddressAllowed(clientAddress.getHostAddress(), allowedIps);
            if (!allowed) {
                refuseSocketConnection(out, clientSocket, "Socket connection from " + clientAddress.getHostAddress() + " closed as it's neither localhost nor in the list of accepted IP addresses.");
            } else if (openSocketConnections.size() >= maxNumberSocketConnections) {
                refuseSocketConnection(out, clientSocket, "Socket connection from " + clientAddress.getHostAddress() + " refused as the maximum number of allowed connections (" + maxNumberSocketConnections + ") has already been reached.");
            } else {
                InternalJythonInterpreter interpreter = new InternalJythonInterpreter(clientSocket);
                InternalJythonInterpreterThread interpreterThread = new InternalJythonInterpreterThread(interpreter);
                out.println("ConnectionOK " + interpreterThread.getId());
                out.flush();
                openSocketConnections.put(interpreterThread.getId(), interpreterThread);
                interpreterThread.start();
            }
        }
    }

    /** Closes a socket, ignoring failures (best effort: nothing useful can be done about them). */
    private static void closeQuietly(Socket socket) {
        try {
            socket.close();
        } catch (IOException ignored) {
            // Best-effort close.
        }
    }

    /**
     * Lists the open interpreter connections.
     *
     * @return one line per connection, formatted as {@code " <threadId> <threadName>\n"}
     */
    @org.lsst.ccs.command.annotations.Command(description = "Get interpreter threads", type = Command.CommandType.QUERY)
    public String getInterpreterThreads() {
        StringBuilder result = new StringBuilder();
        // Lock the map: it is mutated concurrently by the listening and interpreter threads.
        synchronized (openSocketConnections) {
            for (Long id : openSocketConnections.keySet()) {
                result.append(' ').append(id).append(' ')
                        .append(openSocketConnections.get(id).getThreadName()).append('\n');
            }
        }
        return result.toString();
    }

    /**
     * Interrupts the interpreter thread with the given name and its running execution, if any.
     *
     * @param name the thread name, as reported by {@link #getInterpreterThreads()}
     * @throws RuntimeException if no open connection has that name
     */
    @org.lsst.ccs.command.annotations.Command(description = "Abort an interpreter thread by Name", type = Command.CommandType.ACTION)
    public void abortInterpreterThread(String name
    ) {
        InternalJythonInterpreterThread interpreterThread = null;
        synchronized (openSocketConnections) {
            for (InternalJythonInterpreterThread interpreter : openSocketConnections.values()) {
                if (interpreter.getThreadName().equals(name)) {
                    interpreterThread = interpreter;
                    break;
                }
            }
        }
        if (interpreterThread == null) {
            throw new RuntimeException("Cannot find Thread with name " + name);
        }
        interpreterThread.interrupt();
    }

    /** Looks up the connection served by the thread with the given id, or {@code null}. */
    private InternalJythonInterpreterThread getConnection(long id) {
        synchronized (openSocketConnections) {
            return openSocketConnections.get(id);
        }
    }

    /** Removes the connection served by the thread with the given id. */
    private void interpreterThreadTerminated(Long id) {
        synchronized (openSocketConnections) {
            openSocketConnections.remove(id);
        }
    }

    /**
     * Thread serving one socket connection.
     * It has a user-settable display name and also interrupts the running execution when
     * interrupted.
     */
    class InternalJythonInterpreterThread extends Thread {

        private final InternalJythonInterpreter interpreter;
        // Set by this thread ("initializeInterpreter") and read by command/listener threads.
        private volatile String name = getName();

        InternalJythonInterpreterThread(InternalJythonInterpreter interpreter) {
            super(interpreter);
            this.interpreter = interpreter;
        }

        private InternalJythonInterpreter getInterpreter() {
            return interpreter;
        }

        private void setThreadName(String name) {
            this.name = name;
        }

        String getThreadName() {
            return name;
        }

        /** Interrupts the running execution (if any) as well as this thread. */
        @Override
        public void interrupt() {
            interpreter.interrupt();
            super.interrupt();
        }
    }

    /**
     * Reads protocol lines from one client socket and dispatches the content they carry.
     * Each instance owns its own {@link PySystemState} so its output goes only to its own socket.
     */
    class InternalJythonInterpreter implements Runnable {

        private final Socket clientSocket;
        private final BufferedReader in;
        private final PrintWriter out;
        // Set by the reading thread, cleared by JythonProcessingThread, read by postShutdown().
        private volatile boolean executing = false;
        // Written by the reading thread, read by whichever thread calls interrupt().
        private volatile Thread executionThread = null;
        private final PythonInterpreter pyInterpreter;

        public InternalJythonInterpreter(Socket socket) throws IOException {
            this.clientSocket = socket;
            out = new PrintWriter(socket.getOutputStream(), true);
            in = new BufferedReader(new InputStreamReader(socket.getInputStream()));
            // PythonInterpreters cannot share the same PySystemState if they need separate
            // output streams.
            PySystemState state = new PySystemState();

            pyInterpreter = new PythonInterpreter(null, state);
            pyInterpreter.setOut(out);
            pyInterpreter.setErr(out);
        }

        private PythonInterpreter getPythonInterpreter() {
            return pyInterpreter;
        }

        private void interrupt() {
            Thread running = executionThread;
            if (running != null) {
                running.interrupt();
            }
        }

        private PrintWriter getOutputWriter() {
            return out;
        }

        private boolean isExecuting() {
            return executing;
        }

        private void setIsExecuting(boolean executing) {
            this.executing = executing;
        }

        /** Sends the completion marker for {@code contentId} to the client. */
        private void sendDone(String contentId) {
            out.println(DONE_EXECUTION + contentId);
            out.flush();
        }

        @Override
        public void run() {
            try {
                readAndDispatch();
            } catch (IOException ex) {
                logger.warn("Problems while reading from the socket connection: " + ex);
            } finally {
                // Always reached: on client disconnect (readLine() returns null), on I/O errors
                // and on protocol violations. Release the interpreter, the slot and the socket.
                releaseConnection();
            }
        }

        /**
         * Reads lines until the client closes the socket.
         * Content between {@code startContent:<id>} and {@code endContent:} is buffered, then
         * dispatched to {@link #handleContent}.
         *
         * @throws RuntimeException if a content line arrives outside a
         *         {@code startContent:}/{@code endContent:} pair
         */
        private void readAndDispatch() throws IOException {
            StringBuilder executionBuffer = new StringBuilder();
            String contentId = null;
            boolean isValidBuffer = false;
            String inputLine;
            while ((inputLine = in.readLine()) != null) {
                if (inputLine.startsWith(START_CONTENT)) {
                    isValidBuffer = true;
                    contentId = inputLine.substring(START_CONTENT.length());
                    executionBuffer.setLength(0);
                } else if (inputLine.startsWith(END_CONTENT)) {
                    handleContent(executionBuffer.toString(), contentId);
                    isValidBuffer = false;
                } else if (isValidBuffer) {
                    executionBuffer.append(inputLine).append("\n");
                } else {
                    throw new RuntimeException("Invalid Socket Protocol. The content was not initialized.");
                }
            }
        }

        /**
         * Handles one complete content block.
         * The abortInterpreter and initializeInterpreter commands are handled inline; any other
         * content is executed on a new {@link JythonProcessingThread}.
         */
        private void handleContent(String bufferContent, String contentId) {
            if (bufferContent.startsWith(ABORT_INTERPRETER)) {
                String name = bufferContent.substring(ABORT_INTERPRETER.length()).trim();
                if (name.isEmpty()) {
                    name = getConnection(Thread.currentThread().getId()).getThreadName();
                }
                try {
                    JythonInterpreterConsole.this.abortInterpreterThread(name);
                } catch (RuntimeException e) {
                    // Report the failure (e.g. unknown name) instead of killing this connection.
                    e.printStackTrace(out);
                    logger.warn(e.getMessage());
                }
                sendDone(contentId);
            } else if (bufferContent.startsWith(INITIALIZE_INTERPRETER)) {
                String name = bufferContent.substring(INITIALIZE_INTERPRETER.length()).trim();
                getConnection(Thread.currentThread().getId()).setThreadName(name);
                sendDone(contentId);
            } else if (executing) {
                RuntimeException e = new RuntimeException("There is already an execution for this socket connection. Only one execution at a time is allowed");
                e.printStackTrace(out);
                logger.warn(e.getMessage());
            } else if (shuttingDown) {
                RuntimeException e = new RuntimeException("The console is shutting down. It cannot execute any content");
                e.printStackTrace(out);
                logger.warn(e.getMessage());
            } else {
                executing = true;
                executionThread = new JythonProcessingThread(this, bufferContent, contentId);
                executionThread.start();
            }
        }

        /**
         * Cleans up the Python interpreter, removes this connection from the open set and
         * closes the socket. The later steps run even if an earlier one fails.
         */
        private void releaseConnection() {
            try {
                pyInterpreter.cleanup();
            } finally {
                JythonInterpreterConsole.this.interpreterThreadTerminated(Thread.currentThread().getId());
                closeQuietly(clientSocket);
            }
        }
    }

    /**
     * Compiles and executes one content block in the given interpreter, then signals completion
     * to the client.
     */
    private class JythonProcessingThread extends Thread {

        private final InternalJythonInterpreter interpreter;
        private final String contentId, bufferContent;

        JythonProcessingThread(InternalJythonInterpreter interpreter, String bufferContent, String contentId) {
            this.interpreter = interpreter;
            this.contentId = contentId;
            this.bufferContent = bufferContent;
            logger.info("Executing contentId: " + contentId);
            logger.info(bufferContent);
        }

        @Override
        public void run() {
            try {
                PyObject code = interpreter.getPythonInterpreter().compile(bufferContent);
                interpreter.getPythonInterpreter().exec(code);
            } catch (Exception executionException) {
                executionException.printStackTrace(interpreter.getOutputWriter());
            } finally {
                // Clear the flag BEFORE announcing completion. Otherwise a client that sends its
                // next content right after "doneExecution:" can be refused as a concurrent execution.
                interpreter.setIsExecuting(false);
                interpreter.sendDone(contentId);
            }
            logger.info("Done executing contentId: " + contentId);
            CCS.cleanUp();
        }

    }

    @Override
    public void printHelp() {
        HelpFormatter formatter = new HelpFormatter();
        formatter.printHelp(100, "CCS Jython Interpreter Console", "", commandLineOptions, "", true);
    }

    /**
     * Parses the command line, loads the Python system properties and applies the
     * {@code -p/--port} option when it is present.
     */
    @Override
    public void processCommandLineOptions(String[] args) throws ParseException {
        CommandLineParser parser = new BasicParser();
        CommandLine line = parser.parse(commandLineOptions, args, true);

        JythonScriptExecutorUtils.loadPythonSystemProperties();

        if (line.hasOption(PORT_OPTION)) {
            this.portNumber = Integer.parseInt(line.getOptionValue(PORT_OPTION));
        }
    }
}
