package org.lsst.ccs.messaging.util;

import java.util.EnumMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import org.lsst.ccs.bus.definition.Bus;
import org.lsst.ccs.messaging.TransportStateException;

/**
 *
 * @author onoprien
 */
public class MultiQueueDispatcher extends AbstractDispatcher {

// -- Fields : -----------------------------------------------------------------
    
    private final AtomicLong taskID = new AtomicLong();
    private volatile boolean off = true;    
    private final EnumMap<Bus,Object> inExecutors;
    private final EnumMap<Bus,ExecutorService> outNormExecutors;
    private final ThreadPoolExecutor outOobCcExec;
    private final ThreadPoolExecutor oobExec;
    
// -- Life cycle : -------------------------------------------------------------
    
    public MultiQueueDispatcher(String... args) {
        super(args);
        
        // Executors :
        
        EnumMap<Bus,Integer> defThreads = new EnumMap<>(Bus.class);
        defThreads.put(Bus.COMMAND, 1);
        defThreads.put(Bus.LOG, 1);
        defThreads.put(Bus.STATUS, 10);
        inExecutors = new EnumMap<>(Bus.class);
        for (Bus bus : Bus.values()) {
            Integer threads = getIntegerArg("in_"+ bus.name().toLowerCase() +"_threads", args);
            int n = threads == null ? defThreads.get(bus) : threads;
            if (n == 1) {
                inExecutors.put(bus, Executors.newSingleThreadExecutor(new TFactory("MESSAGING_IN_"+ bus)));
            } else {
                inExecutors.put(bus, new KeyQueueExecutor("MESSAGING_IN_"+ bus, n));
            }
        }
        
        outNormExecutors = new EnumMap<>(Bus.class);
        for (Bus bus : Bus.values()) {
            outNormExecutors.put(bus, Executors.newSingleThreadExecutor(new TFactory("MESSAGING_OUT_"+ bus)));
        }
        
        outOobCcExec = new ThreadPoolExecutor(2, 2, 70L, TimeUnit.SECONDS, new LinkedBlockingQueue<>(), new TFactory("MESSAGING_OUT_OOB_CC"));
        outOobCcExec.allowCoreThreadTimeOut(true);
                
        oobExec = new ThreadPoolExecutor(0, Integer.MAX_VALUE, 70L, TimeUnit.SECONDS, new SynchronousQueue<>(), new TFactory("MESSAGING_OOB"));
    }

    @Override
    public void initialize() {
        super.initialize();
        off = false;
    }

    @Override
    public void shutdown() {
        off = true;
        
        oobExec.shutdown();
        outOobCcExec.shutdown();
        outNormExecutors.values().forEach(exec -> exec.shutdown());
        inExecutors.values().forEach(exec -> {
            if (exec instanceof KeyQueueExecutor) {
                ((KeyQueueExecutor) exec).shutdownNow();
            } else {
                ((ExecutorService) exec).shutdownNow();
            }
        });
        
        try {
            oobExec.awaitTermination(1, TimeUnit.MINUTES);
            outOobCcExec.awaitTermination(1, TimeUnit.MINUTES);
            for (ExecutorService exec : outNormExecutors.values()) {
                exec.awaitTermination(1, TimeUnit.MINUTES);
            }
            for (Object exec : inExecutors.values()) {
                if (exec instanceof KeyQueueExecutor) {
                    ((KeyQueueExecutor) exec).awaitTermination(1, TimeUnit.MINUTES);
                } else {
                    ((ExecutorService) exec).awaitTermination(1, TimeUnit.MINUTES);
                }
            }
        } catch (InterruptedException x) {
            Thread.currentThread().interrupt();
        }
        
        super.shutdown();
    }

    /** For testing only. */
    void drainAndShutdown() {
        off = true;
        
        oobExec.shutdown();
        outOobCcExec.shutdown();
        outNormExecutors.values().forEach(exec -> exec.shutdown());
        inExecutors.values().forEach(exec -> {
            if (exec instanceof KeyQueueExecutor) {
                ((KeyQueueExecutor) exec).shutdown();
            } else {
                ((ExecutorService) exec).shutdown();
            }
        });
        
        try {
            oobExec.awaitTermination(1, TimeUnit.MINUTES);
            outOobCcExec.awaitTermination(1, TimeUnit.MINUTES);
            for (ExecutorService exec : outNormExecutors.values()) {
                exec.awaitTermination(1, TimeUnit.MINUTES);
            }
            for (Object exec : inExecutors.values()) {
                if (exec instanceof KeyQueueExecutor) {
                    ((KeyQueueExecutor) exec).awaitTermination(1, TimeUnit.MINUTES);
                } else {
                    ((ExecutorService) exec).awaitTermination(1, TimeUnit.MINUTES);
                }
            }
        } catch (InterruptedException x) {
            Thread.currentThread().interrupt();
        }
        
        super.shutdown();
    }
    
// -- Task submission : --------------------------------------------------------

    /**
     * Submits a task to process an incoming message or disconnection notification.
     * This method returns immediately, without waiting for the task to finish execution.
     * <p>
     * If one or more agent names are given, the task if guaranteed to be executed after any previously submitted
     * tasks for the same bus and agent. If no agents are specified, the task is independent of other tasks,
     * subject to capacity controls imposed by this service. If the {@code agents} argument
     * is {@code null}, the task is independent of other tasks, not subject to capacity controls.
     * 
     * @param run Task to be executed.
     * @param bus Bus (LOG, STATUS, or COMMAND).
     * @param agents Names of affected agents (source or disconnected).  
     */
    @Override
    public void in(Runnable run, Bus bus, String... agents) {
        if (off) return;
        long id = taskID.getAndIncrement();
        long time = System.currentTimeMillis();
        Order order = agents == null ? Order.OOB : (agents.length == 0 ? Order.OOB_CC : Order.NORM);
        report(run, id, true, bus, order, time, Stage.START);
        Task task = new Task(id, run, true, bus, order, time);
        try {
            switch (order) {
                case NORM:
                case OOB_CC:
                    Object exec = inExecutors.get(bus);
                    if (exec instanceof KeyQueueExecutor) {
                        ((KeyQueueExecutor)exec).execute(task, agents);
                    } else {
                        ((ExecutorService)exec).execute(task);
                    }
                    break;
                case OOB:
                    oobExec.execute(task);
                    break;
            }
        } catch (RejectedExecutionException x) {
            throw new TransportStateException(x);
        } finally {
            report(run, id, true, bus, order, System.currentTimeMillis() - time, Stage.SUBMIT);
        }
    }

    /**
     * Submits a task to process an outgoing message.
     * This method returns immediately, without waiting for the task to finish execution.
     * If the service has been shut down, calling this method does nothing.
     * <p>
     * Tasks submitted with {@code outOfBand} equal to {@code false} for the same bus are guaranteed to be
     * processed in the order of submission. If {@code outOfBand} equals {@code true}, the task is executed 
     * independently of others, subject to capacity controls imposed by this service.
     * If {@code outOfBand} equals {@code null}, the task is independent, not subject to capacity controls.
     * 
     * @param run Task to be executed.
     * @param bus Bus (LOG, STATUS, or COMMAND).
     * @param order Order of execution and capacity control policies.
     */
    @Override
    public void out(Runnable run, Bus bus, Order order) {
        if (off) throw new TransportStateException();
        long id = taskID.getAndIncrement();
        long time = System.currentTimeMillis();
        report(run, id, false, bus, order, time, Stage.START);
        Task task = new Task(id, run, false, bus, order, time);
        try {
            switch (order) {
                case NORM:
                    outNormExecutors.get(bus).execute(task);
                    break;
                case OOB_CC:
                    outOobCcExec.execute(task);
                    break;
                case OOB:
                    oobExec.execute(task);
                    break;
            }
        } catch (RejectedExecutionException x) {
            throw new TransportStateException();
        } finally {
            report(run, id, false, bus, order, System.currentTimeMillis() - time, Stage.SUBMIT);
        }
    }
    
    
// -- Local classes : ----------------------------------------------------------
   
    private class Task implements Runnable {

        private final long id;
        private final Runnable runnable;
        private final boolean incoming;
        private final Bus bus;
        private final Order order;
        private long time = System.currentTimeMillis();

        Task(long id, Runnable runnable, boolean incoming, Bus bus, Order order, long time) {
            this.id = id;
            this.runnable = runnable;
            this.incoming = incoming;
            this.bus = bus;
            this.order = order;
            this.time = time;
        }

        @Override
        public void run() {
            try {
                long current = System.currentTimeMillis();
                report(runnable, id, incoming, bus, order, current - time, Stage.WAIT);
                time = current;
                runnable.run();
            } catch (Exception x) {
                getLogger().warn("Error sending message", x);
            } finally {
                try {
                    report(runnable, id, incoming, bus, order, System.currentTimeMillis() - time, Stage.RUN);
                } catch (Exception x) {
                }
            }
        }

    }
    
    private class TFactory implements ThreadFactory {
        
        private final ThreadFactory delegate = Executors.defaultThreadFactory();
        private final String name;
        
        TFactory(String name) {
            this.name = name;
        }

        @Override
        public Thread newThread(Runnable r) {
            Thread thread = delegate.newThread(r);
            thread.setDaemon(true);
            thread.setName(name);
            thread.setUncaughtExceptionHandler((t, x) -> getLogger().warn("Exception thrown from messaging executor "+ name, x));
            return thread;
        }

    }
        
}
